案例學習--Self-Attention及其實現實現


文章鏈接

texto alternativo

第0步. 什么是self-attention?

原文鏈接: Transformer 一篇就夠了(一): Self-attenstion

接下來,我們將要解釋和實現self-attention的全過程。

  • 准備輸入
  • 初始化參數
  • 獲取key,query和value
  • 給input1計算attention score
  • 計算softmax
  • 給value乘上score
  • 給value加權求和獲取output1
  • 重復步驟4-7,獲取output2,output3
import torch

第1步: 准備輸入

為了簡單起見,我們使用3個輸入,每個輸入都是一個4維的向量。

x = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1]  # Input 3
 ]
x = torch.tensor(x, dtype=torch.float32)
x
tensor([[1., 0., 1., 0.],
        [0., 2., 0., 2.],
        [1., 1., 1., 1.]])

第2步: 初始化參數

每一個輸入都有三個表示,分別為key(橙黃色)query(紅色)value(紫色)。比如說,每一個表示我們希望是一個3維的向量。由於輸入是4維,所以我們的參數矩陣為 4\times3 維。

后面我們會看到,value的維度,同樣也是我們輸出的維度。

為了能夠獲取這些表示,每一個輸入(綠色)要和key,query和value相乘,在我們例子中,我們使用如下的方式初始化這些參數。

w_key = [
  [0, 0, 1],
  [1, 1, 0],
  [0, 1, 0],
  [1, 1, 0]
]
w_query = [
  [1, 0, 1],
  [1, 0, 0],
  [0, 0, 1],
  [0, 1, 1]
]
w_value = [
  [0, 2, 0],
  [0, 3, 0],
  [1, 0, 3],
  [1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)

print("Weights for key: \n", w_key)
print("Weights for query: \n", w_query)
print("Weights for value: \n", w_value)
Weights for key: 
 tensor([[0., 0., 1.],
        [1., 1., 0.],
        [0., 1., 0.],
        [1., 1., 0.]])
Weights for query: 
 tensor([[1., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 1.]])
Weights for value: 
 tensor([[0., 2., 0.],
        [0., 3., 0.],
        [1., 0., 3.],
        [1., 1., 0.]])

Note: 通常在神經網絡的初始化過程中,這些參數都是比較小的,一般會在Gaussian, Xavier and Kaiming distributions隨機采樣完成。

第3步:獲取key,query和value

現在我們有了三個參數,現在就讓我們來獲取實際上的key,query和value。

keys的表示為:

               [0, 0, 1]
[1, 0, 1, 0]   [1, 1, 0]   [0, 1, 1]
[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]
[1, 1, 1, 1]   [1, 1, 0]   [2, 3, 1]

values的表示為:

               [0, 2, 0]
[1, 0, 1, 0]   [0, 3, 0]   [1, 2, 3] 
[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]
[1, 1, 1, 1]   [1, 1, 0]   [2, 6, 3]

querys的表示為:

               [1, 0, 1]
[1, 0, 1, 0]   [1, 0, 0]   [1, 0, 2]
[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]
[1, 1, 1, 1]   [0, 1, 1]   [2, 1, 3]

Notes: 在我們實際的應用中,有可能會在點乘后,加上一個bias的向量。

keys = x @ w_key
querys = x @ w_query
values = x @ w_value

print("Keys: \n", keys)
# tensor([[0., 1., 1.],
#         [4., 4., 0.],
#         [2., 3., 1.]])

print("Querys: \n", querys)
# tensor([[1., 0., 2.],
#         [2., 2., 2.],
#         [2., 1., 3.]])
print("Values: \n", values)
# tensor([[1., 2., 3.],
#         [2., 8., 0.],
#         [2., 6., 3.]])
Keys: 
 tensor([[0., 1., 1.],
        [4., 4., 0.],
        [2., 3., 1.]])
Querys: 
 tensor([[1., 0., 2.],
        [2., 2., 2.],
        [2., 1., 3.]])
Values: 
 tensor([[1., 2., 3.],
        [2., 8., 0.],
        [2., 6., 3.]])

第4步: 計算 attention scores

為了獲取input1的attention score,我們使用點乘來處理所有的key和query,包括它自己的key和value。這樣我們就能夠得到3個key的表示(因為我們有3個輸入),我們就獲得了3個attention score(藍色)。

            [0, 4, 2]
[1, 0, 2] x [1, 4, 3] = [2, 4, 4]
            [1, 0, 1]

這里我們需要注意一下,這里我們只有input1的例子。后面,我們會對其他的輸入的query做相同的操作。

attn_scores = querys @ keys.T
print(attn_scores)

# tensor([[ 2.,  4.,  4.],  # attention scores from Query 1
#         [ 4., 16., 12.],  # attention scores from Query 2
#         [ 4., 12., 10.]]) # attention scores from Query 3
tensor([[ 2.,  4.,  4.],
        [ 4., 16., 12.],
        [ 4., 12., 10.]])

第5步: 計算softmax

給attention score應用softmax。

softmax([2, 4, 4]) = [0.0, 0.5, 0.5]
from torch.nn.functional import softmax

attn_scores_softmax = softmax(attn_scores, dim=-1)
print(attn_scores_softmax)
# tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
#         [6.0337e-06, 9.8201e-01, 1.7986e-02],
#         [2.9539e-04, 8.8054e-01, 1.1917e-01]])

# For readability, approximate the above as follows
attn_scores_softmax = [
  [0.0, 0.5, 0.5],
  [0.0, 1.0, 0.0],
  [0.0, 0.9, 0.1]
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)
print(attn_scores_softmax)
tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
        [6.0337e-06, 9.8201e-01, 1.7986e-02],
        [2.9539e-04, 8.8054e-01, 1.1917e-01]])
tensor([[0.0000, 0.5000, 0.5000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 0.9000, 0.1000]])

第6步: 給value乘上score

使用經過softmax后的attention score乘以它對應的value值(紫色),這樣我們就得到了3個weighted values(黃色)。

1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]
2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]
3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]
weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]
print(weighted_values)
tensor([[[0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]],

        [[1.0000, 4.0000, 0.0000],
         [2.0000, 8.0000, 0.0000],
         [1.8000, 7.2000, 0.0000]],

        [[1.0000, 3.0000, 1.5000],
         [0.0000, 0.0000, 0.0000],
         [0.2000, 0.6000, 0.3000]]])

第7步: 給value加權求和獲取output

把所有的weighted values(黃色)進行element-wise的相加。

  [0.0, 0.0, 0.0]
+ [1.0, 4.0, 0.0]
+ [1.0, 3.0, 1.5]
-----------------
= [2.0, 7.0, 1.5]

得到結果向量[2.0, 7.0, 1.5](深綠色)就是ouput1的和其他key交互的query representation。

第8步: 重復步驟4-7,獲取output2,output3

現在,我們已經完成output1的全部計算,我們要對input2和input3也重復的完成步驟4~7的計算。

outputs = weighted_values.sum(dim=0)
print(outputs)

# tensor([[2.0000, 7.0000, 1.5000],  # Output 1
#         [2.0000, 8.0000, 0.0000],  # Output 2
#         [2.0000, 7.8000, 0.3000]]) # Output 3
tensor([[2.0000, 7.0000, 1.5000],
        [2.0000, 8.0000, 0.0000],
        [2.0000, 7.8000, 0.3000]])

福利: Tensorflow 2 實現

import tensorflow as tf
x = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1]  # Input 3
 ]

x = tf.convert_to_tensor(x, dtype=tf.float32)
print(x)
tf.Tensor(
[[1. 0. 1. 0.]
 [0. 2. 0. 2.]
 [1. 1. 1. 1.]], shape=(3, 4), dtype=float32)
w_key = [
  [0, 0, 1],
  [1, 1, 0],
  [0, 1, 0],
  [1, 1, 0]
]
w_query = [
  [1, 0, 1],
  [1, 0, 0],
  [0, 0, 1],
  [0, 1, 1]
]
w_value = [
  [0, 2, 0],
  [0, 3, 0],
  [1, 0, 3],
  [1, 1, 0]
]
w_key = tf.convert_to_tensor(w_key, dtype=tf.float32)
w_query = tf.convert_to_tensor(w_query, dtype=tf.float32)
w_value = tf.convert_to_tensor(w_value, dtype=tf.float32)
print("Weights for key: \n", w_key)
print("Weights for query: \n", w_query)
print("Weights for value: \n", w_value)

Weights for key: 
 tf.Tensor(
[[0. 0. 1.]
 [1. 1. 0.]
 [0. 1. 0.]
 [1. 1. 0.]], shape=(4, 3), dtype=float32)
Weights for query: 
 tf.Tensor(
[[1. 0. 1.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 1. 1.]], shape=(4, 3), dtype=float32)
Weights for value: 
 tf.Tensor(
[[0. 2. 0.]
 [0. 3. 0.]
 [1. 0. 3.]
 [1. 1. 0.]], shape=(4, 3), dtype=float32)
keys = tf.matmul(x, w_key)
querys = tf.matmul(x, w_query)
values = tf.matmul(x, w_value)
print(keys)
print(querys)
print(values)
tf.Tensor(
[[0. 1. 1.]
 [4. 4. 0.]
 [2. 3. 1.]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[1. 0. 2.]
 [2. 2. 2.]
 [2. 1. 3.]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[1. 2. 3.]
 [2. 8. 0.]
 [2. 6. 3.]], shape=(3, 3), dtype=float32)
attn_scores = tf.matmul(querys, keys, transpose_b=True)
print(attn_scores)
tf.Tensor(
[[ 2.  4.  4.]
 [ 4. 16. 12.]
 [ 4. 12. 10.]], shape=(3, 3), dtype=float32)
attn_scores_softmax = tf.nn.softmax(attn_scores, axis=-1)
print(attn_scores_softmax)

# For readability, approximate the above as follows
attn_scores_softmax = [
  [0.0, 0.5, 0.5],
  [0.0, 1.0, 0.0],
  [0.0, 0.9, 0.1]
]
attn_scores_softmax = tf.convert_to_tensor(attn_scores_softmax)
print(attn_scores_softmax)
tf.Tensor(
[[6.3378938e-02 4.6831051e-01 4.6831051e-01]
 [6.0336647e-06 9.8200780e-01 1.7986100e-02]
 [2.9538720e-04 8.8053685e-01 1.1916770e-01]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[0.  0.5 0.5]
 [0.  1.  0. ]
 [0.  0.9 0.1]], shape=(3, 3), dtype=float32)
weighted_values = values[:,None] * tf.transpose(attn_scores_softmax)[:,:,None]
print(weighted_values)
tf.Tensor(
[[[0.  0.  0. ]
  [0.  0.  0. ]
  [0.  0.  0. ]]

 [[1.  4.  0. ]
  [2.  8.  0. ]
  [1.8 7.2 0. ]]

 [[1.  3.  1.5]
  [0.  0.  0. ]
  [0.2 0.6 0.3]]], shape=(3, 3, 3), dtype=float32)
outputs = tf.reduce_sum(weighted_values, axis=0)  # 6
print(outputs)
tf.Tensor(
[[2.        7.        1.5      ]
 [2.        8.        0.       ]
 [2.        7.7999997 0.3      ]], shape=(3, 3), dtype=float32)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM