Ever since Google introduced BERT in 2018, the transformer architecture has dominated NLP. BERT, built on the transformer, swept 11 NLP tasks at the time and set new SOTA results, and later models such as XLNet and RoBERTa also adopted the transformer as their core. On the well-known SOTA machine translation leaderboards, nearly every top-ranked model is a transformer. Before the transformer appeared, the field was dominated by recurrent models such as LSTM and GRU; compared with them, the transformer has two notable advantages:
1. The transformer has no recurrence, so it can make full use of (distributed) GPU parallelism and train far more efficiently.
2. When modeling long sequences, it captures long-range semantic dependencies much more effectively.
Because of the transformer's huge success in NLP, researchers naturally wondered what would happen if it were applied to computer vision, a field long dominated by CNNs: adapting the transformer to vision might open up a new path for CV models. And indeed, the transformer has recently taken CV by storm as well: transformer-based vision models keep appearing at the major conferences, and quite a few of them outperform CNNs.
In this post we start from the original transformer model and walk through it piece by piece.
The Transformer architecture
The overall architecture of the transformer is shown in the figure below:
As the figure shows, the transformer can be divided into four parts: the input, the output, the encoder, and the decoder. Taking machine translation as an example, each part consists of the following:
The input part (orange region) contains:
1. The source-text embedding layer and its positional encoder
2. The target-text embedding layer and its positional encoder
The output part (blue region) contains:
1. A linear layer
2. A softmax layer
The encoder part (red region):
1. A stack of N identical encoder layers
2. Each encoder layer consists of two sub-layer blocks
3. The first block contains a multi-head self-attention layer, a normalization layer, and a residual connection
4. The second block contains a position-wise feed-forward sub-layer, a normalization layer, and a residual connection
The decoder part (purple region):
1. A stack of N identical decoder layers
2. Each decoder layer consists of three sub-layer blocks
3. The first block contains a (masked) multi-head self-attention sub-layer, a normalization layer, and a residual connection
4. The second block contains a multi-head encoder-decoder attention sub-layer, a normalization layer, and a residual connection
5. The third block contains a position-wise feed-forward sub-layer, a normalization layer, and a residual connection
The input part:
Text embedding layer (Input Embedding). Whether for the source text or the target text, the embedding layer converts each token's integer id into a vector representation, so that relationships between tokens can be captured in this high-dimensional space.
Embedding implementation:
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer


# Text embedding layer
class Embedding(Layer):
    '''
    :param vocab: vocabulary size
    :param dim_model: embedding dimension
    '''
    def __init__(self, vocab, dim_model, **kwargs):
        self._vocab = vocab
        self._dim_model = dim_model
        super(Embedding, self).__init__(**kwargs)

    def build(self, input_shape, **kwargs):
        self.embeddings = self.add_weight(
            shape=(self._vocab, self._dim_model),
            initializer='glorot_uniform',
            name='embeddings')
        super(Embedding, self).build(input_shape)

    def call(self, inputs):
        if K.dtype(inputs) != 'int32':
            inputs = K.cast(inputs, 'int32')
        embeddings = K.gather(self.embeddings, inputs)
        # Scale the embeddings by sqrt(d_model)
        embeddings *= self._dim_model ** 0.5
        return embeddings

    def compute_output_shape(self, input_shape):
        return input_shape + (self._dim_model,)
```
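To make the layer concrete, here is a minimal usage sketch that reuses the imports above; the vocabulary size, embedding dimension, and token ids are toy values chosen purely for illustration:

```python
# Toy batch: 2 sentences of 4 token ids each (0 plays the role of padding here)
token_ids = tf.constant([[1, 5, 9, 0],
                         [2, 3, 0, 0]], dtype='int32')

embedding_layer = Embedding(vocab=16, dim_model=8)   # assumed toy sizes
vectors = embedding_layer(token_ids)
print(vectors.shape)  # (2, 4, 8): one 8-dim vector per token, scaled by sqrt(dim_model)
```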
Positional encoding layer (Position Encoding). The transformer encoder itself contains nothing that is aware of token order, so a positional encoder is added after the embedding layer; it injects position information into the embedding tensor, compensating for the fact that the same token can carry different meaning at different positions.
The positional encodings are computed as:
$PE_{(pos,\,2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right)$
$PE_{(pos,\,2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right)$
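As a quick sanity check of these formulas, the following NumPy snippet evaluates them directly; the sequence length and d_model are toy values chosen only for illustration:

```python
import numpy as np

d_model = 8          # toy embedding dimension (assumption)
num_positions = 4    # toy sequence length (assumption)

pe = np.zeros((num_positions, d_model))
for pos in range(num_positions):
    for i in range(0, d_model, 2):
        angle = pos / np.power(10000, i / d_model)
        pe[pos, i] = np.sin(angle)        # even index 2i -> sine
        pe[pos, i + 1] = np.cos(angle)    # odd index 2i+1 -> cosine

print(pe[0])        # position 0: all sine terms are 0, all cosine terms are 1
print(pe[1, :4])    # PE(1,0), PE(1,1), PE(1,2), PE(1,3)
```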
Position Encoding implementation:
```python
# Positional encoding layer
class PositionEncoding(Layer):
    '''
    :param dim_model: embedding dimension
    '''
    def __init__(self, dim_model, **kwargs):
        self._dim_model = dim_model
        super(PositionEncoding, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        seq_length = inputs.shape[1]
        position_encodings = np.zeros((seq_length, self._dim_model))
        for pos in range(seq_length):
            for i in range(self._dim_model):
                position_encodings[pos, i] = pos / np.power(10000, (i - i % 2) / self._dim_model)
        position_encodings[:, 0::2] = np.sin(position_encodings[:, 0::2])  # even indices: 2i
        position_encodings[:, 1::2] = np.cos(position_encodings[:, 1::2])  # odd indices: 2i+1
        position_encodings = K.cast(position_encodings, 'float32')
        return position_encodings

    def compute_output_shape(self, input_shape):
        return input_shape
```
Implementation of the layer that adds the embeddings and the position encodings:
```python
# Layer that adds the embeddings and the position encodings
class Add(Layer):
    def __init__(self, **kwargs):
        super(Add, self).__init__(**kwargs)

    def call(self, inputs, **kwargs):
        embeddings, position_encodings = inputs
        return embeddings + position_encodings

    def compute_output_shape(self, input_shape):
        return input_shape[0]
```
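The three input-side layers can then be chained as below; the shapes are toy assumptions, and the snippet reuses the classes and imports defined above:

```python
token_ids = tf.constant([[4, 7, 1, 0],
                         [2, 9, 0, 0]], dtype='int32')       # (batch=2, seq_len=4)

embeddings = Embedding(vocab=16, dim_model=8)(token_ids)      # (2, 4, 8)
position_encodings = PositionEncoding(8)(embeddings)          # (4, 8), broadcast over the batch
encoder_inputs = Add()([embeddings, position_encodings])      # (2, 4, 8)
print(encoder_inputs.shape)
```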
Encoder and decoder components
Key concepts:
- Mask tensor: "mask" means to cover up, and the values are the entries of a tensor. A mask tensor can have any shape and usually contains only 0s and 1s, indicating whether each position is masked or not. Its purpose is to hide (i.e., replace) certain values in another tensor, and it is itself represented as a tensor.
- Purpose of the mask tensor: in the transformer, masks are mainly used inside attention. Some entries of the attention scores may be computed from information about future positions, because during training the entire target sequence is embedded and fed in at once, whereas the decoder's output is in principle produced step by step, each step conditioned on the previous results. To prevent future information from being used ahead of time, those future positions must be masked out.
- Multi-Head Attention is built from several self-attention heads. Looking at the multi-head attention diagram, it may seem as if "multi-head" refers to multiple sets of linear projection layers, but that is not the case: only one set of linear projections is used, i.e., three weight matrices that transform Q, K, and V. These projections do not change the tensor's dimensionality, so each weight matrix is square. The "multi-head" part only comes into play afterwards: each head operates on a slice of the projected tensor, so every token's representation is split only along the last (word-embedding) dimension, and each head sees just one part of it.
- The structures of self-attention and multi-head attention are shown in the figure below. The computation uses three matrices, Q (query), K (key), and V (value). The actual input is the matrix X formed by the token representation vectors (or the output of the previous encoder layer), and Q, K, V are obtained by applying linear transformations to this input. A minimal standalone sketch of this computation, including the look-ahead mask, follows this list.
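The computation described in the bullets above amounts to Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, with a look-ahead mask that sets future positions to a large negative value before the softmax. Below is a minimal NumPy sketch using toy sizes and random Q/K/V; it is purely illustrative and not the Keras layer implemented next:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

seq_len, d_k = 4, 8                                  # toy sizes (assumptions)
rng = np.random.default_rng(0)
Q, K_mat, V = (rng.normal(size=(seq_len, d_k)) for _ in range(3))

scores = Q @ K_mat.T / np.sqrt(d_k)                  # MatMul + Scale

# Look-ahead (subsequent) mask: position i may only attend to positions <= i.
future = np.triu(np.ones((seq_len, seq_len)), k=1).astype(bool)
scores = np.where(future, -1e9, scores)              # hide future positions

weights = softmax(scores)                            # the upper triangle becomes ~0
output = weights @ V                                 # weighted sum of the values
print(np.round(weights, 2))
```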
Self-Attention layer implementation:
```python
# Scaled dot-product (self-)attention layer
class ScaledDotProductAttention(Layer):
    def __init__(self, masking=True, future=False, dropout_rate=0., **kwargs):
        self._masking = masking
        self._future = future
        self._dropout_rate = dropout_rate
        self._masking_num = -2 ** 32 + 1
        super(ScaledDotProductAttention, self).__init__(**kwargs)

    def mask(self, inputs, masks):
        # Padding mask: add a large negative number where the mask marks a padded position
        masks = K.cast(masks, 'float32')
        masks = K.tile(masks, [K.shape(inputs)[0] // K.shape(masks)[0], 1])
        masks = K.expand_dims(masks, 1)
        outputs = inputs + masks * self._masking_num
        return outputs

    def future_mask(self, inputs):
        # Look-ahead mask: keep only the lower triangle of the score matrix
        diag_vals = tf.ones_like(inputs[0, :, :])
        tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()
        future_masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])
        paddings = tf.ones_like(future_masks) * self._masking_num
        outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
        return outputs

    def call(self, inputs, **kwargs):
        if self._masking:
            assert len(inputs) == 4, "inputs should be set [queries, keys, values, masks]."
            queries, keys, values, masks = inputs
        else:
            assert len(inputs) == 3, "inputs should be set [queries, keys, values]."
            queries, keys, values = inputs

        if K.dtype(queries) != 'float32': queries = K.cast(queries, 'float32')
        if K.dtype(keys) != 'float32': keys = K.cast(keys, 'float32')
        if K.dtype(values) != 'float32': values = K.cast(values, 'float32')

        matmul = K.batch_dot(queries, tf.transpose(keys, [0, 2, 1]))  # MatMul
        scaled_matmul = matmul / int(queries.shape[-1]) ** 0.5        # Scale
        if self._masking:
            scaled_matmul = self.mask(scaled_matmul, masks)           # Mask (opt.)
        if self._future:
            scaled_matmul = self.future_mask(scaled_matmul)

        softmax_out = K.softmax(scaled_matmul)                        # SoftMax
        out = K.dropout(softmax_out, self._dropout_rate)              # Dropout
        outputs = K.batch_dot(out, values)
        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape
```
Multi-Head Attention layer implementation:
```python
# Multi-head (self-)attention layer
class MultiHeadAttention(Layer):

    def __init__(self, n_heads, head_dim, dropout_rate=.1, masking=True, future=False, trainable=True, **kwargs):
        self._n_heads = n_heads
        self._head_dim = head_dim
        self._dropout_rate = dropout_rate
        self._masking = masking
        self._future = future
        self._trainable = trainable
        super(MultiHeadAttention, self).__init__(**kwargs)

    # Square weight matrices linearly transform Q, K, V without changing their dimensionality
    def build(self, input_shape):
        self._weights_queries = self.add_weight(
            shape=(input_shape[0][-1], self._n_heads * self._head_dim),
            initializer='glorot_uniform',
            trainable=self._trainable,
            name='weights_queries')
        self._weights_keys = self.add_weight(
            shape=(input_shape[1][-1], self._n_heads * self._head_dim),
            initializer='glorot_uniform',
            trainable=self._trainable,
            name='weights_keys')
        self._weights_values = self.add_weight(
            shape=(input_shape[2][-1], self._n_heads * self._head_dim),
            initializer='glorot_uniform',
            trainable=self._trainable,
            name='weights_values')
        super(MultiHeadAttention, self).build(input_shape)

    def call(self, inputs, **kwargs):
        if self._masking:
            assert len(inputs) == 4, "inputs should be set [queries, keys, values, masks]."
            queries, keys, values, masks = inputs
        else:
            assert len(inputs) == 3, "inputs should be set [queries, keys, values]."
            queries, keys, values = inputs

        queries_linear = K.dot(queries, self._weights_queries)
        keys_linear = K.dot(keys, self._weights_keys)
        values_linear = K.dot(values, self._weights_values)

        # Split the projected Q, K, V along the word-embedding dimension, one slice per head
        queries_multi_heads = tf.concat(tf.split(queries_linear, self._n_heads, axis=2), axis=0)
        keys_multi_heads = tf.concat(tf.split(keys_linear, self._n_heads, axis=2), axis=0)
        values_multi_heads = tf.concat(tf.split(values_linear, self._n_heads, axis=2), axis=0)

        if self._masking:
            att_inputs = [queries_multi_heads, keys_multi_heads, values_multi_heads, masks]
        else:
            att_inputs = [queries_multi_heads, keys_multi_heads, values_multi_heads]

        attention = ScaledDotProductAttention(
            masking=self._masking, future=self._future, dropout_rate=self._dropout_rate)
        att_out = attention(att_inputs)

        # Concatenate the heads back along the embedding dimension
        outputs = tf.concat(tf.split(att_out, self._n_heads, axis=0), axis=2)

        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape
```
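A short usage sketch (toy shapes, reusing the layers and imports defined above) makes the head bookkeeping visible: the projected last dimension n_heads * head_dim is split into n_heads slices, stacked along the batch axis for the attention computation, and finally concatenated back:

```python
batch, seq_len, model_dim, n_heads = 2, 4, 8, 2     # toy sizes (assumptions)
x = tf.random.normal((batch, seq_len, model_dim))   # stand-in for embedded inputs
masks = tf.zeros((batch, seq_len))                  # all zeros: no padding positions to hide

mha = MultiHeadAttention(n_heads=n_heads, head_dim=model_dim // n_heads)
out = mha([x, x, x, masks])    # self-attention: queries = keys = values
print(out.shape)               # (2, 4, 8): same shape as the input
```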
Position-wise Feed Forward implementation:
```python
# Position-wise feed-forward layer
# out = relu(x W1 + b1) W2 + b2
class PositionWiseFeedForward(Layer):

    def __init__(self, model_dim, inner_dim, trainable=True, **kwargs):
        self._model_dim = model_dim
        self._inner_dim = inner_dim
        self._trainable = trainable
        super(PositionWiseFeedForward, self).__init__(**kwargs)

    def build(self, input_shape):
        self.weights_inner = self.add_weight(
            shape=(input_shape[-1], self._inner_dim),
            initializer='glorot_uniform',
            trainable=self._trainable,
            name="weights_inner")
        self.weights_out = self.add_weight(
            shape=(self._inner_dim, self._model_dim),
            initializer='glorot_uniform',
            trainable=self._trainable,
            name="weights_out")
        self.bias_inner = self.add_weight(
            shape=(self._inner_dim,),
            initializer='uniform',
            trainable=self._trainable,
            name="bias_inner")
        self.bias_out = self.add_weight(
            shape=(self._model_dim,),
            initializer='uniform',
            trainable=self._trainable,
            name="bias_out")
        super(PositionWiseFeedForward, self).build(input_shape)

    def call(self, inputs, **kwargs):
        if K.dtype(inputs) != 'float32':
            inputs = K.cast(inputs, 'float32')
        inner_out = K.relu(K.dot(inputs, self.weights_inner) + self.bias_inner)
        outputs = K.dot(inner_out, self.weights_out) + self.bias_out
        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape[:-1] + (self._model_dim,)
```
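A quick check of the feed-forward block (toy dimensions assumed, reusing the imports above), matching out = relu(x W1 + b1) W2 + b2:

```python
x = tf.random.normal((2, 4, 8))                      # (batch, seq_len, model_dim), toy sizes
ffn = PositionWiseFeedForward(model_dim=8, inner_dim=32)
print(ffn(x).shape)                                  # (2, 4, 8): expand to 32 dims, project back to 8
```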
Normalization implementation:
```python
# Layer normalization
class LayerNormalization(Layer):

    def __init__(self, epsilon=1e-8, **kwargs):
        self._epsilon = epsilon
        super(LayerNormalization, self).__init__(**kwargs)

    def build(self, input_shape):
        self.beta = self.add_weight(
            shape=(input_shape[-1],),
            initializer='zero',
            name='beta')
        self.gamma = self.add_weight(
            shape=(input_shape[-1],),
            initializer='one',
            name='gamma')
        super(LayerNormalization, self).build(input_shape)

    def call(self, inputs, **kwargs):
        # Normalize over the last (feature) dimension, then rescale and shift
        mean, variance = tf.nn.moments(inputs, [-1], keepdims=True)
        normalized = (inputs - mean) / ((variance + self._epsilon) ** 0.5)
        outputs = self.gamma * normalized + self.beta
        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape
```
Full Transformer implementation:
```python
class Transformer(Layer):
    def __init__(self, vocab_size, model_dim, n_heads=8, encoder_stack=6, decoder_stack=6,
                 feed_forward_size=2048, dropout=0.1, **kwargs):
        self._vocab_size = vocab_size
        self._model_dim = model_dim
        self._n_heads = n_heads
        self._encoder_stack = encoder_stack
        self._decoder_stack = decoder_stack
        self._feed_forward_size = feed_forward_size
        self._dropout_rate = dropout
        super(Transformer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.embeddings = self.add_weight(
            shape=(self._vocab_size, self._model_dim),
            initializer='glorot_uniform',
            trainable=True,
            name="embeddings")
        super(Transformer, self).build(input_shape)

    def encoder(self, inputs):
        if K.dtype(inputs) != 'int32':
            inputs = K.cast(inputs, 'int32')

        # Padding mask: marks positions where the token id is 0
        masks = K.equal(inputs, 0)
        # Embeddings
        embeddings = Embedding(self._vocab_size, self._model_dim)(inputs)
        # Position Encodings
        position_encodings = PositionEncoding(self._model_dim)(embeddings)
        # Embeddings + Position Encodings
        encodings = embeddings + position_encodings
        # Dropout
        encodings = K.dropout(encodings, self._dropout_rate)

        # Encoder stack
        for i in range(self._encoder_stack):
            # Multi-head Attention
            attention = MultiHeadAttention(self._n_heads, self._model_dim // self._n_heads)
            attention_input = [encodings, encodings, encodings, masks]
            attention_out = attention(attention_input)
            # Add & Norm
            attention_out += encodings
            attention_out = LayerNormalization()(attention_out)
            # Feed-Forward
            pwff = PositionWiseFeedForward(self._model_dim, self._feed_forward_size)
            pwff_out = pwff(attention_out)
            # Add & Norm
            pwff_out += attention_out
            encodings = LayerNormalization()(pwff_out)

        return encodings, masks

    def decoder(self, inputs):
        decoder_inputs, encoder_encodings, encoder_masks = inputs
        if K.dtype(decoder_inputs) != 'int32':
            decoder_inputs = K.cast(decoder_inputs, 'int32')

        decoder_masks = K.equal(decoder_inputs, 0)
        # Embeddings
        embeddings = Embedding(self._vocab_size, self._model_dim)(decoder_inputs)
        # Position Encodings
        position_encodings = PositionEncoding(self._model_dim)(embeddings)
        # Embeddings + Position Encodings
        encodings = embeddings + position_encodings
        # Dropout
        encodings = K.dropout(encodings, self._dropout_rate)

        # Decoder stack
        for i in range(self._decoder_stack):
            # Masked Multi-head Attention
            masked_attention = MultiHeadAttention(self._n_heads, self._model_dim // self._n_heads, future=True)
            masked_attention_input = [encodings, encodings, encodings, decoder_masks]
            masked_attention_out = masked_attention(masked_attention_input)
            # Add & Norm
            masked_attention_out += encodings
            masked_attention_out = LayerNormalization()(masked_attention_out)

            # Multi-head Attention over the encoder output
            attention = MultiHeadAttention(self._n_heads, self._model_dim // self._n_heads)
            attention_input = [masked_attention_out, encoder_encodings, encoder_encodings, encoder_masks]
            attention_out = attention(attention_input)
            # Add & Norm
            attention_out += masked_attention_out
            attention_out = LayerNormalization()(attention_out)

            # Feed-Forward
            pwff = PositionWiseFeedForward(self._model_dim, self._feed_forward_size)
            pwff_out = pwff(attention_out)
            # Add & Norm
            pwff_out += attention_out
            encodings = LayerNormalization()(pwff_out)

        # Pre-softmax projection shares its weights with the embedding matrix
        linear_projection = K.dot(encodings, K.transpose(self.embeddings))
        outputs = K.softmax(linear_projection)
        return outputs

    def call(self, inputs, **kwargs):
        encoder_inputs, decoder_inputs = inputs
        encoder_encodings, encoder_masks = self.encoder(encoder_inputs)
        decoder_outputs = self.decoder([decoder_inputs, encoder_encodings, encoder_masks])
        return decoder_outputs

    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], input_shape[0][1], self._vocab_size)
```
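To close, here is a hedged end-to-end sketch that simply calls the Transformer layer above on a pair of random token-id batches (reusing the imports from earlier); all sizes are toy assumptions, and in practice the layer would be wrapped in a Keras Model with proper inputs, loss, and training loop:

```python
vocab_size, model_dim, max_len = 100, 32, 6          # toy hyper-parameters (assumptions)

encoder_tokens = tf.constant(
    np.random.randint(1, vocab_size, size=(2, max_len)), dtype='int32')
decoder_tokens = tf.constant(
    np.random.randint(1, vocab_size, size=(2, max_len)), dtype='int32')

transformer = Transformer(vocab_size, model_dim, n_heads=4,
                          encoder_stack=2, decoder_stack=2, feed_forward_size=64)
probs = transformer([encoder_tokens, decoder_tokens])
print(probs.shape)   # (2, 6, 100): a distribution over the vocabulary for each target position
```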
In the next post we will use the transformer to build a BERT network for a hands-on text sentiment classification task.