I have recently been studying Transformer-related network architectures, so I worked through the classic paper Attention Is All You Need and, following its source code, dug into the classic attention structure.
This post covers the following:
Paper link: https://arxiv.org/abs/1706.03762
Part 1. Transformer structure and principles.
The first part introduces the structure, modules, and formulas of Attention Is All You Need. I will not yet explain what Q, K, V are, what attention is, or what encoding/decoding means; instead, I will explain them through the code walkthrough, which makes them easier to understand.
① Structure: the Transformer is built from nothing but self-attention and feed-forward neural networks, i.e. multi-head attention and FFN blocks, as shown in the figure below.
② Modules: besides the multi-head attention and FFN mentioned above, a positional encoding module and a mask module are also required.
③ Formulas:
Positional encoding formula (many alternatives exist; the paper uses the following one).
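For reference, the sinusoidal formula from the paper, which the code later in this post implements:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$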
Q, K, V formula (scaled dot-product attention and its multi-head extension).
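Again for reference, the attention formulas from the paper:

$$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$

$$\mathrm{MultiHead}(Q,K,V) = \mathrm{Concat}(head_1,\ldots,head_h)\,W^{O}, \quad head_i = \mathrm{Attention}(QW_i^{Q},\,KW_i^{K},\,VW_i^{V})$$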
The FFN is basically nn.Linear layers plus an activation; it is explained with code later.
Part 2. Code walkthrough.
The second part starts from the model inputs and works layer by layer through the whole encoding and decoding process, as well as the attention, FFN, positional encoding, and mask computations used along the way.
ENCODER module:
① Encoder input data:
enc_input = [
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3]]
The encoder input is a 4×6 array: 4 sentences, each containing 6 tokens (punctuation included).
② Embedding and positional encoding of the input
Input embedding:
self.src_emb = nn.Embedding(vocab_size, d_model) # d_model=128
vocab_size: the size of the vocabulary; for example, if there are 5000 distinct tokens in total, pass 5000 (indices then range from 0 to 4999). d_model: the dimension of the embedding vector, i.e. how many dimensions are used to represent each token or symbol.
With x = enc_input, the embedded output enc_outputs has shape [4, 6, 128]: batch size 4, sequence length 6, and 128 dimensions describing each token.
x = self.src_emb(x)  # token embedding
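A minimal shape check (a self-contained sketch using the toy data above; vocab_size=5 and d_model=128 are the values used later in this post):

import torch
import torch.nn as nn

enc_input = torch.as_tensor([[1, 3, 4, 1, 2, 3]] * 4, dtype=torch.long)  # [4, 6]
src_emb = nn.Embedding(5, 128)   # vocab_size=5, d_model=128
x = src_emb(enc_input)           # token embedding
print(x.shape)                   # torch.Size([4, 6, 128])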
Positional encoding:
The code below implements the positional encoding formula directly, so it needs no further explanation.
pe = torch.zeros(max_len, d_model)
position = torch.arange(0., max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))  # frequency term
pe[:, 0::2] = torch.sin(position * div_term)  # even-indexed columns get sin
pe[:, 1::2] = torch.cos(position * div_term)  # odd-indexed columns get cos
pe = pe.unsqueeze(0)
After positional encoding, the [1, 6, 128] positional tensor is added to the [4, 6, 128] embeddings (broadcast over the batch), so each sentence now carries positional information and serves as the new input.
x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)  # Variable is legacy PyTorch; requires_grad=False keeps pe out of the autograd graph
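As a quick self-contained check of those shapes (a sketch that recomputes the pe table exactly as above and uses a random stand-in for the embedded batch):

import math
import torch

max_len, d_model = 5000, 128
pe = torch.zeros(max_len, d_model)
position = torch.arange(0., max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)              # [1, 5000, 128]

x = torch.randn(4, 6, 128)        # stand-in for the embedded batch
x = x + pe[:, :x.size(1)]         # the [1, 6, 128] slice broadcasts over the batch
print(x.shape)                    # torch.Size([4, 6, 128])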
③ Self-attention:
Before diving in, one piece of background: if X and Y (the source of the queries and the source of the keys/values) are the same, the operation is self-attention; otherwise it is cross-attention. The decoder uses cross-attention because there X != Y.
The Q, K, V projections are just linear transforms: the input x above is mapped to [4, 6, 512], the 512 dimensions are then split into 8 heads of size 64 (d_k = d_v = 64), and after transposing the dimensions the result is [4, 8, 6, 64].
self.WQ = nn.Linear(d_model, d_k * n_heads)  # linear projections (nn.Linear layers)
self.WK = nn.Linear(d_model, d_k * n_heads)
self.WV = nn.Linear(d_model, d_v * n_heads)
The transformed q, k, v:
q_s = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # project, then split into heads
k_s = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
v_s = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)  # replicate the mask for each head
Then the scaled dot-product attention formula above is applied:
scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
attn = nn.Softmax(dim=-1)(scores)   # (mask handling omitted in this excerpt; see ScaledDotProductAttention in the full code)
context = torch.matmul(attn, V)
The above is the attention output; we then reshape it back:
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # merge the heads back
output = self.linear(context)              # project back to the original d_model dimension
output = self.layer_norm(output + Q)       # residual connection with Q, followed by LayerNorm
This again yields a new input x with shape [4, 6, 128].
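To make the shape bookkeeping concrete, here is a self-contained sketch of one multi-head attention pass with the dimensions used in this post (d_model=128, n_heads=8, d_k=d_v=64); the input is a random stand-in:

import numpy as np
import torch
import torch.nn as nn

batch, seq_len, d_model, n_heads, d_k, d_v = 4, 6, 128, 8, 64, 64
x = torch.randn(batch, seq_len, d_model)                    # [4, 6, 128]

WQ = nn.Linear(d_model, d_k * n_heads)                      # 128 -> 512
WK = nn.Linear(d_model, d_k * n_heads)
WV = nn.Linear(d_model, d_v * n_heads)
out_proj = nn.Linear(n_heads * d_v, d_model)                # 512 -> 128

q = WQ(x).view(batch, -1, n_heads, d_k).transpose(1, 2)     # [4, 8, 6, 64]
k = WK(x).view(batch, -1, n_heads, d_k).transpose(1, 2)     # [4, 8, 6, 64]
v = WV(x).view(batch, -1, n_heads, d_v).transpose(1, 2)     # [4, 8, 6, 64]

scores = torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(d_k)    # [4, 8, 6, 6]
context = torch.matmul(torch.softmax(scores, dim=-1), v)        # [4, 8, 6, 64]
context = context.transpose(1, 2).contiguous().view(batch, -1, n_heads * d_v)  # [4, 6, 512]
print(out_proj(context).shape)                                   # torch.Size([4, 6, 128])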
④ FFN:
The [4, 6, 128] output above is passed through the FFN layer, which is essentially a linear residual block, to obtain the final output.
class PoswiseFeedForwardNet(nn.Module):

    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.l1 = nn.Linear(d_model, d_ff)
        self.l2 = nn.Linear(d_ff, d_model)

        self.relu = GELU()
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        residual = inputs
        output = self.l1(inputs)    # linear layer, d_model -> d_ff
        output = self.relu(output)  # GELU activation
        output = self.l2(output)    # linear layer, d_ff -> d_model
        return self.layer_norm(output + residual)
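A short usage sketch (assuming torch, the GELU class, and the other imports from the full code at the end of this post are in scope; d_model=128 and d_ff=256 are the values used later):

ffn = PoswiseFeedForwardNet(d_model=128, d_ff=256)
out = ffn(torch.randn(4, 6, 128))
print(out.shape)   # torch.Size([4, 6, 128]) -- the residual block preserves the shape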
⑤ Repeat the steps above: take the FFN output x of shape [4, 6, 128] and feed it back through steps ③-④. The paper stacks 6 encoder layers, so repeating 5 more times completes the encoder module.
DECODER module:
① Decoder input data: the decoder takes the dec_input below, together with enc_input and the encoder output of shape [4, 6, 128]:
dec_input = [
[1, 0, 0, 0, 0, 0],
[1, 3, 0, 0, 0, 0],
[1, 3, 4, 0, 0, 0],
[1, 3, 4, 1, 0, 0]]
② Embedding and positional encoding of dec_input
This is identical to the encoder implementation; simply replace enc_input with dec_input to obtain dec_outputs, so it is not repeated here.
③ Mask computation, consisting of a global (padding) mask and a local (subsequent) mask
The global (padding) mask, code as follows:
def get_attn_pad_mask(seq_q, seq_k, pad_index):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(pad_index).unsqueeze(1)
    pad_attn_mask = torch.as_tensor(pad_attn_mask, dtype=torch.int)
    return pad_attn_mask.expand(batch_size, len_q, len_k)
Applied to dec_input, this code produces the following mask (1 marks a padding position):
[[0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1]]
A dimension is added to make it [4, 1, 6], and it is then expanded to [4, 6, 6].
The local (subsequent) mask, which is simply an upper-triangular matrix:
[[0. 1. 1. 1. 1. 1.]
[0. 0. 1. 1. 1. 1.]
[0. 0. 0. 1. 1. 1.]
[0. 0. 0. 0. 1. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 0.]]
This matrix gets an extra dimension to become [1, 6, 6] and is then expanded to [4, 6, 6].
My understanding of the global vs. local masks: the global (padding) mask is built per sentence from the decoder input (4 sentences of 6 tokens), while the local (subsequent) mask is a single 6×6 pattern repeated across the 4 sentences; combining them gives the mask both the padding information and the causal (no peeking ahead) information.
dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)  # global (padding) mask
dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)                # local (subsequent) mask
dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)  # torch.gt(a, b): 1 where a > b, else 0
dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.pad_index)
Finally the two masks are merged into dec_self_attn_mask; likewise dec_enc_attn_mask (whose shape is decoder length × encoder length) is obtained with just the first, padding-only step.
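A self-contained sketch that builds both masks for the toy dec_input and combines them exactly as in the code above:

import numpy as np
import torch

dec_input = torch.as_tensor([
    [1, 0, 0, 0, 0, 0],
    [1, 3, 0, 0, 0, 0],
    [1, 3, 4, 0, 0, 0],
    [1, 3, 4, 1, 0, 0]], dtype=torch.long)

pad_mask = dec_input.eq(0).unsqueeze(1).int().expand(4, 6, 6)        # global (padding) mask, [4, 6, 6]
sub_mask = torch.from_numpy(np.triu(np.ones((4, 6, 6)), k=1)).int()  # local (subsequent) mask, [4, 6, 6]
dec_self_attn_mask = torch.gt(pad_mask + sub_mask, 0)                # True where attention is forbidden
print(dec_self_attn_mask[0].int())  # first sentence: only position 0 is a real token, so every query may attend only to position 0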
④ Decoder attention, consisting of two parts
Self-attention on the decoder input dec_outputs:
This uses the same Q, K, V formulas and the same implementation as the encoder. The only difference is that the decoder mask dec_self_attn_mask is applied to Q·Kᵀ; the key line is scores.masked_fill_(attn_mask, -1e9). The rest of the code is:
class ScaledDotProductAttention(nn.Module):

    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.device = device
        self.d_k = d_k

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        attn_mask = torch.as_tensor(attn_mask, dtype=torch.bool)
        attn_mask = attn_mask.to(self.device)
        scores.masked_fill_(attn_mask, -1e9)  # where the mask is True, set the score to -1e9
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn
It is invoked as follows:
context, attn = ScaledDotProductAttention(d_k=self.d_k, device=self.device)(Q=q_s, K=k_s, V=v_s,
                                                                            attn_mask=attn_mask)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # merge the heads back
output = self.linear(context)              # project back to the original d_model dimension
dec_outputs = self.layer_norm(output + Q)  # residual connection with Q, followed by LayerNorm
That completes the decoder self-attention module, whose output is dec_outputs; apart from the extra mask applied to Q·Kᵀ, it is identical to the encoder.
Cross-attention between the decoder output dec_outputs and the encoder output:
dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)  # note: enc_outputs comes from the encoder and stays fixed for every decoder layer
This is the cross-attention step: Q comes from dec_outputs while K and V come from the encoder output enc_outputs; this is the X ≠ Y case the paper describes, hence "cross" attention.
The implementation is exactly the same as the decoder self-attention; only the mask changes (dec_enc_attn_mask, built as described above), producing a new dec_outputs, so it is not explained again.
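A sketch of the cross-attention call shapes (stand-in tensors; assuming torch and the MultiHeadAttention class from the full code at the end of this post are in scope):

dec_outputs = torch.randn(4, 6, 128)   # decoder-side sequence (target)
enc_outputs = torch.randn(4, 6, 128)   # encoder output (source), reused by every decoder layer
dec_enc_attn_mask = torch.zeros(4, 6, 6, dtype=torch.bool)   # [batch, dec_len, enc_len]

mha = MultiHeadAttention(d_model=128, d_k=64, d_v=64, n_heads=8, device=torch.device('cpu'))
out, attn = mha(Q=dec_outputs, K=enc_outputs, V=enc_outputs, attn_mask=dec_enc_attn_mask)
print(out.shape, attn.shape)   # torch.Size([4, 6, 128]) torch.Size([4, 8, 6, 6])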
⑤ FFN.
After obtaining dec_outputs from the attention in step ④, apply the same FFN as in encoder step ④.
⑥ Repeating steps ④-⑤ for each decoder layer completes the decoding process.
At this point, the encoder and decoder of Attention Is All You Need have been fully walked through.
Key personal takeaways:
① No ordinary kernel-3 CNN convolutions are used; everything is done with nn.Linear projections;
② The encoder supplies K and V, the decoder supplies Q;
③ Self-attention and cross-attention differ only in whether X and Y are the same, i.e. where Q versus K/V come from; the implementation is identical;
④ The key Transformer modules are attention (usually multi-head attention), the FFN, positional encoding, and masking.
Finally, the complete code is attached for readers who want to dig deeper.
Full code:
import json
import math
import torch
import torchvision
import torch.nn as nn
import numpy as np
from pdb import set_trace

from torch.autograd import Variable


def get_attn_pad_mask(seq_q, seq_k, pad_index):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(pad_index).unsqueeze(1)
    pad_attn_mask = torch.as_tensor(pad_attn_mask, dtype=torch.int)
    return pad_attn_mask.expand(batch_size, len_q, len_k)


def get_attn_subsequent_mask(seq):
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequent_mask = np.triu(np.ones(attn_shape), k=1)
    subsequent_mask = torch.from_numpy(subsequent_mask).int()
    return subsequent_mask


class GELU(nn.Module):

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))  # frequency term
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)  # register pe as a buffer: saved with the module, no gradient computed

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)  # Variable is legacy; requires_grad=False keeps pe out of autograd
        return self.dropout(x)


class ScaledDotProductAttention(nn.Module):

    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.device = device
        self.d_k = d_k

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        attn_mask = torch.as_tensor(attn_mask, dtype=torch.bool)
        attn_mask = attn_mask.to(self.device)
        scores.masked_fill_(attn_mask, -1e9)  # where the mask is True, set the score to -1e9
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn


class MultiHeadAttention(nn.Module):

    def __init__(self, d_model, d_k, d_v, n_heads, device):
        super(MultiHeadAttention, self).__init__()
        self.WQ = nn.Linear(d_model, d_k * n_heads)  # linear projections for Q, K, V
        self.WK = nn.Linear(d_model, d_k * n_heads)
        self.WV = nn.Linear(d_model, d_v * n_heads)

        self.linear = nn.Linear(n_heads * d_v, d_model)

        self.layer_norm = nn.LayerNorm(d_model)
        self.device = device

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.n_heads = n_heads

    def forward(self, Q, K, V, attn_mask):
        batch_size = Q.shape[0]
        q_s = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # project, then split into heads
        k_s = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)  # replicate the mask for each head
        context, attn = ScaledDotProductAttention(d_k=self.d_k, device=self.device)(Q=q_s, K=k_s, V=v_s,
                                                                                    attn_mask=attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # merge the heads back
        output = self.linear(context)  # project back to the original d_model dimension
        return self.layer_norm(output + Q), attn  # residual connection with Q, followed by LayerNorm


class PoswiseFeedForwardNet(nn.Module):

    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.l1 = nn.Linear(d_model, d_ff)
        self.l2 = nn.Linear(d_ff, d_model)

        self.relu = GELU()
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        residual = inputs
        output = self.l1(inputs)    # linear layer, d_model -> d_ff
        output = self.relu(output)  # GELU activation
        output = self.l2(output)    # linear layer, d_ff -> d_model
        return self.layer_norm(output + residual)


class EncoderLayer(nn.Module):

    def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
        self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(Q=enc_inputs, K=enc_inputs, V=enc_inputs, attn_mask=enc_self_attn_mask)
        # X = Y here, so Q, K and V are the same tensor (self-attention)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn


class Encoder(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
        # e.g. vocab_size=5, d_model=128, d_ff=256, d_k=64, d_v=64, n_heads=8, n_layers=4, pad_index=0
        super(Encoder, self).__init__()
        self.device = device
        self.pad_index = pad_index
        self.src_emb = nn.Embedding(vocab_size, d_model)
        # vocab_size: size of the vocabulary, e.g. 5000 tokens -> indices 0-4999; d_model: embedding dimension per token
        self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)

        self.layers = []
        for _ in range(n_layers):
            encoder_layer = EncoderLayer(d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
            self.layers.append(encoder_layer)
        self.layers = nn.ModuleList(self.layers)

    def forward(self, x):
        enc_outputs = self.src_emb(x)  # token embedding
        enc_outputs = self.pos_emb(enc_outputs)  # add positional encoding
        enc_self_attn_mask = get_attn_pad_mask(x, x, self.pad_index)

        enc_self_attns = []
        for layer in self.layers:
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)

        enc_self_attns = torch.stack(enc_self_attns)
        enc_self_attns = enc_self_attns.permute([1, 0, 2, 3, 4])
        return enc_outputs, enc_self_attns


class DecoderLayer(nn.Module):

    def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
        self.dec_enc_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
        self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn, dec_enc_attn


class Decoder(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
        super(Decoder, self).__init__()
        self.pad_index = pad_index
        self.device = device
        self.tgt_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
        self.layers = []
        for _ in range(n_layers):
            decoder_layer = DecoderLayer(d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
            self.layers.append(decoder_layer)
        self.layers = nn.ModuleList(self.layers)

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        dec_outputs = self.tgt_emb(dec_inputs)
        dec_outputs = self.pos_emb(dec_outputs)

        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)
        dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.pad_index)

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            dec_outputs, dec_self_attn, dec_enc_attn = layer(
                dec_inputs=dec_outputs,
                enc_outputs=enc_outputs,
                dec_self_attn_mask=dec_self_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        dec_self_attns = torch.stack(dec_self_attns)
        dec_enc_attns = torch.stack(dec_enc_attns)

        dec_self_attns = dec_self_attns.permute([1, 0, 2, 3, 4])
        dec_enc_attns = dec_enc_attns.permute([1, 0, 2, 3, 4])

        return dec_outputs, dec_self_attns, dec_enc_attns


class MaskedDecoderLayer(nn.Module):

    def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
        super(MaskedDecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
        self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)

    def forward(self, dec_inputs, dec_self_attn_mask):
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn


class MaskedDecoder(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff, d_k,
                 d_v, n_heads, n_layers, pad_index, device):
        super(MaskedDecoder, self).__init__()
        self.pad_index = pad_index
        self.tgt_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)

        self.layers = []
        for _ in range(n_layers):
            decoder_layer = MaskedDecoderLayer(
                d_model=d_model, d_ff=d_ff,
                d_k=d_k, d_v=d_v, n_heads=n_heads,
                device=device)
            self.layers.append(decoder_layer)
        self.layers = nn.ModuleList(self.layers)

    def forward(self, dec_inputs):
        dec_outputs = self.tgt_emb(dec_inputs)
        dec_outputs = self.pos_emb(dec_outputs)

        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)
        dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
        dec_self_attns = []
        for layer in self.layers:
            dec_outputs, dec_self_attn = layer(
                dec_inputs=dec_outputs,
                dec_self_attn_mask=dec_self_attn_mask)
            dec_self_attns.append(dec_self_attn)
        dec_self_attns = torch.stack(dec_self_attns)
        dec_self_attns = dec_self_attns.permute([1, 0, 2, 3, 4])
        return dec_outputs, dec_self_attns


class BertModel(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
        super(BertModel, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = PositionalEncoding(d_model=d_model, dropout=0)
        self.seg_embed = nn.Embedding(2, d_model)

        self.layers = []
        for _ in range(n_layers):
            encoder_layer = EncoderLayer(
                d_model=d_model, d_ff=d_ff,
                d_k=d_k, d_v=d_v, n_heads=n_heads,
                device=device)
            self.layers.append(encoder_layer)
        self.layers = nn.ModuleList(self.layers)

        self.pad_index = pad_index

        self.fc = nn.Linear(d_model, d_model)
        self.active1 = nn.Tanh()
        self.classifier = nn.Linear(d_model, 2)

        self.linear = nn.Linear(d_model, d_model)
        self.active2 = GELU()
        self.norm = nn.LayerNorm(d_model)

        self.decoder = nn.Linear(d_model, vocab_size, bias=False)
        self.decoder.weight = self.tok_embed.weight
        self.decoder_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.tok_embed(input_ids) + self.seg_embed(segment_ids)
        output = self.pos_embed(output)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.pad_index)

        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        h_pooled = self.active1(self.fc(output[:, 0]))
        logits_clsf = self.classifier(h_pooled)

        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.norm(self.active2(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias

        return logits_lm, logits_clsf, output


class GPTModel(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff,
                 d_k, d_v, n_heads, n_layers, pad_index,
                 device):
        super(GPTModel, self).__init__()
        self.decoder = MaskedDecoder(
            vocab_size=vocab_size,
            d_model=d_model, d_ff=d_ff,
            d_k=d_k, d_v=d_v, n_heads=n_heads,
            n_layers=n_layers, pad_index=pad_index,
            device=device)
        self.projection = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, dec_inputs):
        dec_outputs, dec_self_attns = self.decoder(dec_inputs)
        dec_logits = self.projection(dec_outputs)
        return dec_logits, dec_self_attns


class Classifier(nn.Module):

    def __init__(self, vocab_size, d_model, d_ff,
                 d_k, d_v, n_heads, n_layers,
                 pad_index, device, num_classes):
        super(Classifier, self).__init__()
        self.encoder = Encoder(
            vocab_size=vocab_size,
            d_model=d_model, d_ff=d_ff,
            d_k=d_k, d_v=d_v, n_heads=n_heads,
            n_layers=n_layers, pad_index=pad_index,
            device=device)
        self.projection = nn.Linear(d_model, num_classes)

    def forward(self, enc_inputs):
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        mean_enc_outputs = torch.mean(enc_outputs, dim=1)
        logits = self.projection(mean_enc_outputs)
        return logits, enc_self_attns


class Translation(nn.Module):

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model,
                 d_ff, d_k, d_v, n_heads, n_layers, src_pad_index,
                 tgt_pad_index, device):
        super(Translation, self).__init__()
        self.encoder = Encoder(
            vocab_size=src_vocab_size,                   # 5
            d_model=d_model, d_ff=d_ff,                  # 128 256
            d_k=d_k, d_v=d_v, n_heads=n_heads,           # 64 64 8
            n_layers=n_layers, pad_index=src_pad_index,  # 4 0
            device=device)
        self.decoder = Decoder(
            vocab_size=tgt_vocab_size,                   # 5
            d_model=d_model, d_ff=d_ff,                  # 128 256
            d_k=d_k, d_v=d_v, n_heads=n_heads,           # 64 64 8
            n_layers=n_layers, pad_index=tgt_pad_index,  # 4 0
            device=device)
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)

    # def forward(self, enc_inputs, dec_inputs, decode_lengths):
    #     enc_outputs, enc_self_attns = self.encoder(enc_inputs)
    #     dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
    #     dec_logits = self.projection(dec_outputs)
    #     return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns, decode_lengths

    def forward(self, enc_inputs, dec_inputs):
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs)
        return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns


if __name__ == '__main__':
    enc_input = [
        [1, 3, 4, 1, 2, 3],
        [1, 3, 4, 1, 2, 3],
        [1, 3, 4, 1, 2, 3],
        [1, 3, 4, 1, 2, 3]]
    dec_input = [
        [1, 0, 0, 0, 0, 0],
        [1, 3, 0, 0, 0, 0],
        [1, 3, 4, 0, 0, 0],
        [1, 3, 4, 1, 0, 0]]
    enc_input = torch.as_tensor(enc_input, dtype=torch.long).to(torch.device('cpu'))
    dec_input = torch.as_tensor(dec_input, dtype=torch.long).to(torch.device('cpu'))
    model = Translation(
        src_vocab_size=5, tgt_vocab_size=5, d_model=128,
        d_ff=256, d_k=64, d_v=64, n_heads=8, n_layers=4, src_pad_index=0,
        tgt_pad_index=0, device=torch.device('cpu'))

    logits, _, _, _ = model(enc_input, dec_input)
    print(logits)