Attention is all you need 深入解析


 

  最近一直在看有關transformer相關網絡結構,為此我特意將經典結構 Attention is all you need 論文進行了解讀,並根據其源碼深入解讀attntion經典結構,

為此本博客將介紹如下內容:

 

論文鏈接:https://arxiv.org/abs/1706.03762

 

 一.Transformer結構與原理解釋。

第一部分介紹Attention is all you need 結構、模塊、公式。暫時不介紹什么Q K V 什么Attention 什么編解碼等,單我將會根據代碼解讀介紹,讓讀者更容易理解。

①結構: Transformer由且僅由self.Attention和Feed Forward Neural Network組成,即mutil-head-attention與FFN,如下圖。

 

 

 

 ②模塊結構:除了以上提到mutil-head-attention與FFN外,還需有個位置編碼結構positional encoding以及mask編碼模塊。

③公式:

位置編碼公式(還有很多其它公式,該論文使用此公式)

 

 

 Q K V公式

 

 

FFN基本是由nn.Linear線性和激活變化,在后面用代碼講解。

二.代碼解讀。

第二部分會從模型輸入開始,層層遞推介紹整個編碼和解碼過程、以及整個過程中使用的Attention編碼、FFN編碼、位置編碼等。

 

ENCODE模塊:

 

① 編碼輸入數據介紹:

enc_input = [
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3],
[1, 3, 4, 1, 2, 3]]
編碼使用輸入數據,為4x6行,表示4個句子,每個句子有6個單詞,包含標點符號。



② 輸入值的Embedding與位置編碼

輸入值embedding:
self.src_emb = nn.Embedding(vocab_size, d_model) # d_model=128
vocab_size:詞典的大小尺寸,比如總共出現5000個詞,那就輸入5000。此時index為(0-4999)d_model:嵌入向量的維度,即用多少維來表示一個詞或符號
隨后可將輸入x=enc_input,可將enc_outputs則表示嵌入成功,維度為[4,6,128]分別表示batch為4,詞為6,用128維度描述詞6
x = self.src_emb(x) # 詞嵌入
位置編碼:
以下使用位置編碼公式的代碼,為此無需再介紹了。
1 pe = torch.zeros(max_len, d_model)
2         position = torch.arange(0., max_len).unsqueeze(1)
3         div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))  # 偶數列
4         pe[:, 0::2] = torch.sin(position * div_term) # 奇數列
5         pe[:, 1::2] = torch.cos(position * div_term)
6         pe = pe.unsqueeze(0)

將編碼進行位置編碼后,位置為[1,6,128]+輸入編碼的[4,6,128],相當於句子已經結合了位置編碼信息,作為新新的輸入。

x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)  # torch.autograd.Variable 表示有梯度的張量變量


③self.attention的編碼:
在介紹此之前,先普及一個知識,若X與Y相等,則為self attention 否則為cross-attention,因為解碼時候X!=Y.

 

 獲取Q K V 代碼,實際是一個線性變化,將以上輸入x變成[4,6,512],然后通過head個數8與對應dv,dk將512拆分[8,64],隨后移維度位置,變成[4,8,6,64]

1 self.WQ = nn.Linear(d_model, d_k * n_heads)  # 利用線性卷積
2 self.WK = nn.Linear(d_model, d_k * n_heads)
3 self.WV = nn.Linear(d_model, d_v * n_heads)

變化后的q k v

1 q_s = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # 線性卷積后再分組實現head功能
2 k_s = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
3 v_s = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
4 attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)  # 編導對應的頭

隨后通過以上self公式,將其編碼計算

1 scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
5 attn = nn.Softmax(dim=-1)(scores)
6 context = torch.matmul(attn, V)

以上編碼將是encode編碼得到結果,我們將得到結果進行還原:

1context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # 將其還原
2output = self.linear(context)  # 通過線性又將其變成原來模樣維度
3layer_norm(output + Q)  # 這里加Q 實際是對Q尋找

以上將重新得到新的輸入x,維度為[4,6,128]

 

 

④ FFN編碼:

將以上的輸出維度為[4,6,128]進行FNN層變化,實際類似線性殘差網絡變化,得到最終輸出

 

 1 class PoswiseFeedForwardNet(nn.Module):
 2 
 3     def __init__(self, d_model, d_ff):
 4         super(PoswiseFeedForwardNet, self).__init__()
 5         self.l1 = nn.Linear(d_model, d_ff)
 6         self.l2 = nn.Linear(d_ff, d_model)
 7 
 8         self.relu = GELU()
 9         self.layer_norm = nn.LayerNorm(d_model)
10 
11     def forward(self, inputs):
12         residual = inputs
13         output = self.l1(inputs)  # 一層線性卷積
14         output = self.relu(output)
15         output = self.l2(output)  # 一層線性卷積
16         return self.layer_norm(output + residual)

 

⑤ 重復以上步驟編碼,即將得到經過FFN變化的輸出x,維度為[4,6,128],將其重復步驟③-④,因其編碼為6個,可重復5個便是完成相應的編碼模塊。

 

 

 

DECODE模塊:

 

①解碼輸入數據介紹,包含以下數據輸入(dec_input)、enc_input的輸入與解碼后輸出的數據,維度為[4,6,128]:

dec_input = [
[1, 0, 0, 0, 0, 0],
[1, 3, 0, 0, 0, 0],
[1, 3, 4, 0, 0, 0],
[1, 3, 4, 1, 0, 0]]


②dec_input的Embedding與位置編碼
因其與encode的實現方法一致,只需將enc_input使用dec_input取代,得到dec_outputs,因此這里將不在介紹。

③mask編碼,包含整體編碼與局部編碼
整體編碼,代碼如下:
1 def get_attn_pad_mask(seq_q, seq_k, pad_index):
2     batch_size, len_q = seq_q.size()
3     batch_size, len_k = seq_k.size()
4     pad_attn_mask = seq_k.data.eq(pad_index).unsqueeze(1)
5     pad_attn_mask = torch.as_tensor(pad_attn_mask, dtype=torch.int)
6     return pad_attn_mask.expand(batch_size, len_q, len_k)

以上代碼實際是將dec_input進行處理,實際變成以下數據:

   [[0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1]]

將其增添維度為[4,1,6],並將其擴張為[4,6,6]

局部代碼編寫,實際為上三角矩陣:

[[0. 1. 1. 1. 1. 1.]
[0. 0. 1. 1. 1. 1.]
[0. 0. 0. 1. 1. 1.]
[0. 0. 0. 0. 1. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 0.]]
將以上數據添加維度為[1,6,6],在將擴展變成[4,6,6]
關於整體mask與局部mask編碼,我的理解是整體信息為語句4個詞6個,根據解碼輸入編碼整體信息,而局部編碼是基於一個語句6*6編碼信息,將其擴張重復到4個語句,
使其mask獲得整體信息與局部信息。
1         dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)  # 整體編碼的mask
2         dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
3         dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)  # torch.gt(a,b) a>b 則為1否則為0
4         dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.pad_index)

最終將mask整合,獲取dec_self_attn_mask信息,同理dec_enc_attn_mask(維度為解碼編碼詞維度)采用dec_self_attn_mask的第一步便可獲取。

 

④編碼輸入self-Attention,包含2部分

解碼輸入dec_outputs進行self.Attention:

實際使用以上Q K V公式,具體實現和編碼實現方法一致,唯一不同是

在Q*KT會使用解碼maskdec_self_attn_mask,其重要代碼為scores.masked_fill_(attn_mask, -1e9),其它代碼為:

 1 class ScaledDotProductAttention(nn.Module):
 2 
 3     def __init__(self, d_k, device):
 4         super(ScaledDotProductAttention, self).__init__()
 5         self.device = device
 6         self.d_k = d_k
 7 
 8     def forward(self, Q, K, V, attn_mask):
 9         scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
10         attn_mask = torch.as_tensor(attn_mask, dtype=torch.bool)
11         attn_mask = attn_mask.to(self.device)
12         scores.masked_fill_(attn_mask, -1e9)  # it is true give -1e9
13         attn = nn.Softmax(dim=-1)(scores)
14         context = torch.matmul(attn, V)
15         return context, attn

 以上代碼將執行以下代碼:

context, attn = ScaledDotProductAttention(d_k=self.d_k, device=self.device)(Q=q_s, K=k_s, V=v_s,
attn_mask=attn_mask)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) # 將其還原
output = self.linear(context) # 通過線性又將其變成原來模樣維度
dec_outputs = self.layer_norm(output + Q) # 這里加Q 實際是對Q尋找

 到此為止已經完成了解碼輸入的self-attention模塊,輸出為dec_outputs實際除了增加mask編碼調整Q*KT以外,其它完全相同。

編碼輸出dec_outputs進行Cross Attention:

dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask) # 重點說明enc_outputs來源編碼結果,是一直不變的
以上為Cross Attention 過程,以上代碼除了Q來源dec_outputs,K V 來源編碼輸出enc_outputs以外,即論文所說X與Y不等得到的Q K V稱為Cross Attention。
實際以上代碼與執行解碼self-Attention方法完全一致,僅僅mask更改上文提供的方法,得到輸出結果為dec_outputs,因此這里將不在解釋了。


⑤ FFN編碼。
通過④的attention編碼,得到dec_outputs后,采用編碼步驟④的FNN方法。



⑥ 重復步驟④-⑤多次,便實現了解碼過程。

至此,本文已完全解讀完Attention is all you need的編碼與解碼結構。
 

 個人重點總結:

①未使用通常kernel=3的CNN卷積,而所有均使用Linear卷積;

②編碼傳遞K V 解碼傳遞Q;

③self-attention 和 cross attention本質是X與Y值不同,即得到Q 和 K V 數據來源不同,但實現方法一致;

④ transformer重點模塊為attention(一般是mutil-head attention)、FFN、位置編碼、mask編碼;

 

 

 最后貼上完整代碼,便於讀者深入理解:

整體代碼:

  1 import json
  2 import math
  3 import torch
  4 import torchvision
  5 import torch.nn as nn
  6 import numpy as np
  7 from pdb import set_trace
  8 
  9 from torch.autograd import Variable
 10 
 11 
 12 def get_attn_pad_mask(seq_q, seq_k, pad_index):
 13     batch_size, len_q = seq_q.size()
 14     batch_size, len_k = seq_k.size()
 15     pad_attn_mask = seq_k.data.eq(pad_index).unsqueeze(1)
 16     pad_attn_mask = torch.as_tensor(pad_attn_mask, dtype=torch.int)
 17     return pad_attn_mask.expand(batch_size, len_q, len_k)
 18 
 19 
 20 def get_attn_subsequent_mask(seq):
 21     attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
 22     subsequent_mask = np.triu(np.ones(attn_shape), k=1)
 23     subsequent_mask = torch.from_numpy(subsequent_mask).int()
 24     return subsequent_mask
 25 
 26 
 27 class GELU(nn.Module):
 28 
 29     def forward(self, x):
 30         return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
 31 
 32 
 33 class PositionalEncoding(nn.Module):
 34     "Implement the PE function."
 35 
 36     def __init__(self, d_model, dropout, max_len=5000):  #
 37         super(PositionalEncoding, self).__init__()
 38         self.dropout = nn.Dropout(p=dropout)
 39 
 40         # Compute the positional encodings once in log space.
 41         pe = torch.zeros(max_len, d_model)
 42         position = torch.arange(0., max_len).unsqueeze(1)
 43         div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))  # 偶數列
 44         pe[:, 0::2] = torch.sin(position * div_term)
 45         pe[:, 1::2] = torch.cos(position * div_term)
 46         pe = pe.unsqueeze(0)
 47         self.register_buffer('pe', pe)  # 將變量pe保存到內存中,不計算梯度
 48 
 49     def forward(self, x):
 50         x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)  # torch.autograd.Variable 表示有梯度的張量變量
 51         return self.dropout(x)
 52 
 53 
 54 class ScaledDotProductAttention(nn.Module):
 55 
 56     def __init__(self, d_k, device):
 57         super(ScaledDotProductAttention, self).__init__()
 58         self.device = device
 59         self.d_k = d_k
 60 
 61     def forward(self, Q, K, V, attn_mask):
 62         scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
 63         attn_mask = torch.as_tensor(attn_mask, dtype=torch.bool)
 64         attn_mask = attn_mask.to(self.device)
 65         scores.masked_fill_(attn_mask, -1e9)  # it is true give -1e9
 66         attn = nn.Softmax(dim=-1)(scores)
 67         context = torch.matmul(attn, V)
 68         return context, attn
 69 
 70 
 71 class MultiHeadAttention(nn.Module):
 72 
 73     def __init__(self, d_model, d_k, d_v, n_heads, device):
 74         super(MultiHeadAttention, self).__init__()
 75         self.WQ = nn.Linear(d_model, d_k * n_heads)  # 利用線性卷積
 76         self.WK = nn.Linear(d_model, d_k * n_heads)
 77         self.WV = nn.Linear(d_model, d_v * n_heads)
 78 
 79         self.linear = nn.Linear(n_heads * d_v, d_model)
 80 
 81         self.layer_norm = nn.LayerNorm(d_model)
 82         self.device = device
 83 
 84         self.d_model = d_model
 85         self.d_k = d_k
 86         self.d_v = d_v
 87         self.n_heads = n_heads
 88 
 89     def forward(self, Q, K, V, attn_mask):
 90         batch_size = Q.shape[0]
 91         q_s = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # 線性卷積后再分組實現head功能
 92         k_s = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
 93         v_s = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
 94 
 95         attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)  # 編導對應的頭
 96         context, attn = ScaledDotProductAttention(d_k=self.d_k, device=self.device)(Q=q_s, K=k_s, V=v_s,
 97                                                                                     attn_mask=attn_mask)
 98         context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # 將其還原
 99         output = self.linear(context)  # 通過線性又將其變成原來模樣維度
100         return self.layer_norm(output + Q), attn  # 這里加Q 實際是對Q尋找
101 
102 
103 class PoswiseFeedForwardNet(nn.Module):
104 
105     def __init__(self, d_model, d_ff):
106         super(PoswiseFeedForwardNet, self).__init__()
107         self.l1 = nn.Linear(d_model, d_ff)
108         self.l2 = nn.Linear(d_ff, d_model)
109 
110         self.relu = GELU()
111         self.layer_norm = nn.LayerNorm(d_model)
112 
113     def forward(self, inputs):
114         residual = inputs
115         output = self.l1(inputs)  # 一層線性卷積
116         output = self.relu(output)
117         output = self.l2(output)  # 一層線性卷積
118         return self.layer_norm(output + residual)
119 
120 
121 class EncoderLayer(nn.Module):
122 
123     def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
124         super(EncoderLayer, self).__init__()
125         self.enc_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
126         self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)
127 
128     def forward(self, enc_inputs, enc_self_attn_mask):
129         enc_outputs, attn = self.enc_self_attn(Q=enc_inputs, K=enc_inputs, V=enc_inputs, attn_mask=enc_self_attn_mask)
130         # X=Y 因此Q K V相等
131         enc_outputs = self.pos_ffn(enc_outputs)  #
132         return enc_outputs, attn
133 
134 
135 class Encoder(nn.Module):
136 
137     def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
138         #                   4        128     256   64   64     8        4          0
139         super(Encoder, self).__init__()
140         self.device = device
141         self.pad_index = pad_index
142         self.src_emb = nn.Embedding(vocab_size, d_model)
143         # vocab_size:詞典的大小尺寸,比如總共出現5000個詞,那就輸入5000。此時index為(0-4999) d_model:嵌入向量的維度,即用多少維來表示一個符號
144         self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
145 
146         self.layers = []
147         for _ in range(n_layers):
148             encoder_layer = EncoderLayer(d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
149             self.layers.append(encoder_layer)
150         self.layers = nn.ModuleList(self.layers)
151 
152     def forward(self, x):
153         enc_outputs = self.src_emb(x)  # 詞嵌入
154         enc_outputs = self.pos_emb(enc_outputs)  # pos+matx
155         enc_self_attn_mask = get_attn_pad_mask(x, x, self.pad_index)
156 
157         enc_self_attns = []
158         for layer in self.layers:
159             enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
160             enc_self_attns.append(enc_self_attn)
161 
162         enc_self_attns = torch.stack(enc_self_attns)
163         enc_self_attns = enc_self_attns.permute([1, 0, 2, 3, 4])
164         return enc_outputs, enc_self_attns
165 
166 
167 class DecoderLayer(nn.Module):
168 
169     def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
170         super(DecoderLayer, self).__init__()
171         self.dec_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
172         self.dec_enc_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
173         self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)
174 
175     def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
176         dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
177         dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
178         dec_outputs = self.pos_ffn(dec_outputs)
179         return dec_outputs, dec_self_attn, dec_enc_attn
180 
181 
182 class Decoder(nn.Module):
183 
184     def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
185         super(Decoder, self).__init__()
186         self.pad_index = pad_index
187         self.device = device
188         self.tgt_emb = nn.Embedding(vocab_size, d_model)
189         self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
190         self.layers = []
191         for _ in range(n_layers):
192             decoder_layer = DecoderLayer(d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
193             self.layers.append(decoder_layer)
194         self.layers = nn.ModuleList(self.layers)
195 
196     def forward(self, dec_inputs, enc_inputs, enc_outputs):
197         dec_outputs = self.tgt_emb(dec_inputs)
198         dec_outputs = self.pos_emb(dec_outputs)
199 
200         dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)
201         dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
202         dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
203         dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.pad_index)
204 
205         dec_self_attns, dec_enc_attns = [], []
206         for layer in self.layers:
207             dec_outputs, dec_self_attn, dec_enc_attn = layer(
208                 dec_inputs=dec_outputs,
209                 enc_outputs=enc_outputs,
210                 dec_self_attn_mask=dec_self_attn_mask,
211                 dec_enc_attn_mask=dec_enc_attn_mask)
212             dec_self_attns.append(dec_self_attn)
213             dec_enc_attns.append(dec_enc_attn)
214         dec_self_attns = torch.stack(dec_self_attns)
215         dec_enc_attns = torch.stack(dec_enc_attns)
216 
217         dec_self_attns = dec_self_attns.permute([1, 0, 2, 3, 4])
218         dec_enc_attns = dec_enc_attns.permute([1, 0, 2, 3, 4])
219 
220         return dec_outputs, dec_self_attns, dec_enc_attns
221 
222 
223 class MaskedDecoderLayer(nn.Module):
224 
225     def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
226         super(MaskedDecoderLayer, self).__init__()
227         self.dec_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
228         self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)
229 
230     def forward(self, dec_inputs, dec_self_attn_mask):
231         dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
232         dec_outputs = self.pos_ffn(dec_outputs)
233         return dec_outputs, dec_self_attn
234 
235 
236 class MaskedDecoder(nn.Module):
237 
238     def __init__(self, vocab_size, d_model, d_ff, d_k,
239                  d_v, n_heads, n_layers, pad_index, device):
240         super(MaskedDecoder, self).__init__()
241         self.pad_index = pad_index
242         self.tgt_emb = nn.Embedding(vocab_size, d_model)
243         self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
244 
245         self.layers = []
246         for _ in range(n_layers):
247             decoder_layer = MaskedDecoderLayer(
248                 d_model=d_model, d_ff=d_ff,
249                 d_k=d_k, d_v=d_v, n_heads=n_heads,
250                 device=device)
251             self.layers.append(decoder_layer)
252         self.layers = nn.ModuleList(self.layers)
253 
254     def forward(self, dec_inputs):
255         dec_outputs = self.tgt_emb(dec_inputs)
256         dec_outputs = self.pos_emb(dec_outputs)
257 
258         dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)
259         dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
260         dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
261         dec_self_attns = []
262         for layer in self.layers:
263             dec_outputs, dec_self_attn = layer(
264                 dec_inputs=dec_outputs,
265                 dec_self_attn_mask=dec_self_attn_mask)
266             dec_self_attns.append(dec_self_attn)
267         dec_self_attns = torch.stack(dec_self_attns)
268         dec_self_attns = dec_self_attns.permute([1, 0, 2, 3, 4])
269         return dec_outputs, dec_self_attns
270 
271 
272 class BertModel(nn.Module):
273 
274     def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
275         super(BertModel, self).__init__()
276         self.tok_embed = nn.Embedding(vocab_size, d_model)
277         self.pos_embed = PositionalEncoding(d_model=d_model, dropout=0)
278         self.seg_embed = nn.Embedding(2, d_model)
279 
280         self.layers = []
281         for _ in range(n_layers):
282             encoder_layer = EncoderLayer(
283                 d_model=d_model, d_ff=d_ff,
284                 d_k=d_k, d_v=d_v, n_heads=n_heads,
285                 device=device)
286             self.layers.append(encoder_layer)
287         self.layers = nn.ModuleList(self.layers)
288 
289         self.pad_index = pad_index
290 
291         self.fc = nn.Linear(d_model, d_model)
292         self.active1 = nn.Tanh()
293         self.classifier = nn.Linear(d_model, 2)
294 
295         self.linear = nn.Linear(d_model, d_model)
296         self.active2 = GELU()
297         self.norm = nn.LayerNorm(d_model)
298 
299         self.decoder = nn.Linear(d_model, vocab_size, bias=False)
300         self.decoder.weight = self.tok_embed.weight
301         self.decoder_bias = nn.Parameter(torch.zeros(vocab_size))
302 
303     def forward(self, input_ids, segment_ids, masked_pos):
304         output = self.tok_embed(input_ids) + self.seg_embed(segment_ids)
305         output = self.pos_embed(output)
306         enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.pad_index)
307 
308         for layer in self.layers:
309             output, enc_self_attn = layer(output, enc_self_attn_mask)
310 
311         h_pooled = self.active1(self.fc(output[:, 0]))
312         logits_clsf = self.classifier(h_pooled)
313 
314         masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
315         h_masked = torch.gather(output, 1, masked_pos)
316         h_masked = self.norm(self.active2(self.linear(h_masked)))
317         logits_lm = self.decoder(h_masked) + self.decoder_bias
318 
319         return logits_lm, logits_clsf, output
320 
321 
322 class GPTModel(nn.Module):
323 
324     def __init__(self, vocab_size, d_model, d_ff,
325                  d_k, d_v, n_heads, n_layers, pad_index,
326                  device):
327         super(GPTModel, self).__init__()
328         self.decoder = MaskedDecoder(
329             vocab_size=vocab_size,
330             d_model=d_model, d_ff=d_ff,
331             d_k=d_k, d_v=d_v, n_heads=n_heads,
332             n_layers=n_layers, pad_index=pad_index,
333             device=device)
334         self.projection = nn.Linear(d_model, vocab_size, bias=False)
335 
336     def forward(self, dec_inputs):
337         dec_outputs, dec_self_attns = self.decoder(dec_inputs)
338         dec_logits = self.projection(dec_outputs)
339         return dec_logits, dec_self_attns
340 
341 
342 class Classifier(nn.Module):
343 
344     def __init__(self, vocab_size, d_model, d_ff,
345                  d_k, d_v, n_heads, n_layers,
346                  pad_index, device, num_classes):
347         super(Classifier, self).__init__()
348         self.encoder = Encoder(
349             vocab_size=vocab_size,
350             d_model=d_model, d_ff=d_ff,
351             d_k=d_k, d_v=d_v, n_heads=n_heads,
352             n_layers=n_layers, pad_index=pad_index,
353             device=device)
354         self.projection = nn.Linear(d_model, num_classes)
355 
356     def forward(self, enc_inputs):
357         enc_outputs, enc_self_attns = self.encoder(enc_inputs)
358         mean_enc_outputs = torch.mean(enc_outputs, dim=1)
359         logits = self.projection(mean_enc_outputs)
360         return logits, enc_self_attns
361 
362 
363 class Translation(nn.Module):
364 
365     def __init__(self, src_vocab_size, tgt_vocab_size, d_model,
366                  d_ff, d_k, d_v, n_heads, n_layers, src_pad_index,
367                  tgt_pad_index, device):
368         super(Translation, self).__init__()
369         self.encoder = Encoder(
370             vocab_size=src_vocab_size,  # 5
371             d_model=d_model, d_ff=d_ff,  # 128  256
372             d_k=d_k, d_v=d_v, n_heads=n_heads,  # 64 64  8
373             n_layers=n_layers, pad_index=src_pad_index,  # 4  0
374             device=device)
375         self.decoder = Decoder(
376             vocab_size=tgt_vocab_size,  # 5
377             d_model=d_model, d_ff=d_ff,  # 128  256
378             d_k=d_k, d_v=d_v, n_heads=n_heads,  # 64 64  8
379             n_layers=n_layers, pad_index=tgt_pad_index,  # 4  0
380             device=device)
381         self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
382 
383     # def forward(self, enc_inputs, dec_inputs, decode_lengths):
384     #     enc_outputs, enc_self_attns = self.encoder(enc_inputs)
385     #     dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
386     #     dec_logits = self.projection(dec_outputs)
387     #     return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns, decode_lengths
388 
389     def forward(self, enc_inputs, dec_inputs):
390         enc_outputs, enc_self_attns = self.encoder(enc_inputs)
391         dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
392         dec_logits = self.projection(dec_outputs)
393         return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns
394 
395 
396 if __name__ == '__main__':
397     enc_input = [
398         [1, 3, 4, 1, 2, 3],
399         [1, 3, 4, 1, 2, 3],
400         [1, 3, 4, 1, 2, 3],
401         [1, 3, 4, 1, 2, 3]]
402     dec_input = [
403         [1, 0, 0, 0, 0, 0],
404         [1, 3, 0, 0, 0, 0],
405         [1, 3, 4, 0, 0, 0],
406         [1, 3, 4, 1, 0, 0]]
407     enc_input = torch.as_tensor(enc_input, dtype=torch.long).to(torch.device('cpu'))
408     dec_input = torch.as_tensor(dec_input, dtype=torch.long).to(torch.device('cpu'))
409     model = Translation(
410         src_vocab_size=5, tgt_vocab_size=5, d_model=128,
411         d_ff=256, d_k=64, d_v=64, n_heads=8, n_layers=4, src_pad_index=0,
412         tgt_pad_index=0, device=torch.device('cpu'))
413 
414     logits, _, _, _ = model(enc_input, dec_input)
415     print(logits)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM