d2l中的valid_lens用法解析


X的維度: torch.Size([64, 10, 32]) batchsz=64,seq_len=10,dim=32。
其實很好理解啊,X的維度是[64, 10, 32];所以valid_lens要mask它啊,所以,肯定是(64,10),現在 裂變成4個head,所以就是(256,10)。

part0

valid_lens的維度:[64],每個句子對應一個。從數據集中取出來就是如此。

注意Mulitiheadattention的forward函數中,注意X的這個過程:把dim裂開維(num_heads,dim);然后num_heads提前、和batchsz合並,num_heads=4,故X變成了(256,10,8)。
在Mulitiheadattention的過attention之前,valid_lens擴增了頭數:

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for
            # `num_heads` times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_heads,
                                                 dim=0)

此時,valid_lens的維度:[256]。只是復制了4個頭。

【0】最終使用時,valid_lens到了這里:

class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of `queries`: (`batch_size`, no. of queries, `d`)
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value
    # dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Set `transpose_b=True` to swap the last two dimensions of `keys`
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

self.attention_weights = masked_softmax(scores, valid_lens),
此時其形狀 還是torch.Size([256]);

def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape               #XDotProductAttention中的scores,表示單頭注意力的分數。shape: torch.Size([256, 10, 10])
        if valid_lens.dim() == 1:     #若為1維【3,4,5,4,3】這樣,長度為256;
            valid_lens = torch.repeat_interleave(valid_lens, shape[1]) 
                                      #則需要擴增shape[1]倍,也就是序列長度倍num_steps. 若為自注意力,確實應該是每個單詞都占一行,
                                      #按理說,過程是這樣的:[256]變成[256,10],然后經歷else中的過程變為[2560]。 [256]變成[2560]。  
                                      #編碼器過程中,用【3,4,5,4,3】這樣直接復制10份,是因為每一個3都表示:第i行對其他值的注意力程度。每一行都代表一個單詞,所以每個單詞都 
                                      #對應着一樣的長度序列,也就是這句完成的話,這里就是復制了10個3,[3*10]里這都表示的是對一個句子里的單詞的注意力。至於為什么是10不是3,為 
                                      #了處理方便吧;但在mask時會把第3行之后的都mask掉。
        else:                         #【1】進入這里
            valid_lens = valid_lens.reshape(-1) # [256,10] 變為[2560] #這是由於,decoder的valid_lens生成的時候就是矩陣那樣生成的。[1,2,..10],[1,2..,10]這樣,所以每一個剛剛好對應了10步。表示了[1,2,..10]中,預測的第一條單詞只對1前的輸入有注意力,所以也是一樣的:第i行對所有的其他單詞的注意力,但此時不能往后看;所以分別為[1,2,..10]。這里一個[1,2,..10]表示的是第i個單詞對一個句子里的單詞的注意力,相當於query_key乘積的scores的scores[i]第i行,因此每一個值scores[i][j](如為3)表示這一行能看多少個有效單詞。
        
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)# 對X [256, 10, 10]變為[2560,10],對其無效的位置填充上-1e6
        return nn.functional.softmax(X.reshape(shape), dim=-1) #填完后,恢復[256, 10, 10],之后對最后一維softmax

注釋1:

def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    maxlen = X.size(1)                                                     # 10
    mask = torch.arange((maxlen), dtype=torch.float32,                     #
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X

part1

【1】而decoder的forward函數中定義了dec_valid_lens的求法,它是使用repeat函數,將repeat的參數[1,..,10]作為元素t,擴充為形狀為(batch_size,1)的矩陣,[[t],[t]];
因此dec_valid_lens維度變為(64,10),具體的,就是[[1,2,3,..10],..,[1,2,..,10]]:

            dec_valid_lens = torch.arange(1, num_steps + 1,
                                          device=X.device).repeat(
                                              batch_size, 1) #<BOS>平移; 按時間步mask,dec_valid_lens 輸出見下 batchsz=64 num_steps=10

與part0相同,注意Mulitiheadattention的forward函數中,X的這個過程:把dim裂開維(num_heads,dim);然后num_heads提前、和batchsz合並,num_heads=4,故X變成了(256,10,8)。
在Mulitiheadattention的過attention之前,valid_lens也擴增了頭數:

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for
            # `num_heads` times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(valid_lens,
                                                 repeats=self.num_heads,
                                                 dim=0)

此時,由於在dim0上操作,valid_lens的維度:(64,10)就變為 [256,10]。【因為在dim0上復制num_heads=4次】
valid_lens在這里為 torch.Size([256, 10]),繼續看下面的函數。

def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape               #XDotProductAttention中的scores,表示單頭注意力的分數。shape: torch.Size([256, 10, 10])
        if valid_lens.dim() == 1:     #若為1維【3,4,5,4,3】這樣,長度為256;
            valid_lens = torch.repeat_interleave(valid_lens, shape[1]) #則需要擴增shape[1]倍,也就是序列長度倍num_steps. 若為自注意力,確實應該是每個單詞都占一行,
                                           #按理說,過程是這樣的:[256]變成[256,10],然后經歷else中的過程變為[2560]。 [256]變成[2560]。  
                                           #編碼器過程中,用【3,4,5,4,3】這樣直接復制10份,是因為每一個3都表示:第i行對其他值的注意力程度。每一行都代表一個單詞,所以每個單詞都對應着一樣的長度序列,也就是這句完成的話,這里就是復制了10個3,[3*10]里這都表示的是對一個句子里的單詞的注意力。至於為什么是10不是3,為了處理方便吧;但在mask時會把第3行之后的都mask掉。
        else:                         #【1】進入這里
            valid_lens = valid_lens.reshape(-1) # [256,10] 變為[2560] #這是由於,decoder的valid_lens生成的時候就是矩陣那樣生成的。[1,2,..10],[1,2..,10]這樣,所以每一個剛剛好對應了10步。表示了[1,2,..10]中,預測的第一條單詞只對1前的輸入有注意力,所以也是一樣的:第i行對所有的其他單詞的注意力,但此時不能往后看;所以分別為[1,2,..10]。這里一個[1,2,..10]表示的是第i個單詞對一個句子里的單詞的注意力,相當於query_key乘積的scores的scores[i]第i行,因此每一個值scores[i][j](如為3)表示這一行能看多少個有效單詞。
        
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)# 對X [256, 10, 10]變為[2560,10],對其無效的位置填充上-1e6
        return nn.functional.softmax(X.reshape(shape), dim=-1) #填完后,恢復[256, 10, 10],之后對最后一維softmax

注釋1:

def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    maxlen = X.size(1)                                                     # 10
    mask = torch.arange((maxlen), dtype=torch.float32,                     #這里注釋1,它只會
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X

torch.arange((maxlen))=[1,2,..,10]
[[0,1,2,..,9]]<[[1],[2],..,[10]]得到的就是主對角線(不包含包含)以上的右上半角全為False,主對角線(包含)以下全為True的下三角矩陣。這樣,上三角(不含對角線)全為-1e6。softmax后全為0。
合理!因為第i行能看到第i個輸入。如第0行表示,第0個單詞能看到(且只能看到)句子中第0個token ,故其在[0][0]處注意力分數存在且等於1。

但是這里是否有漏洞?最后一行的最后幾個字符就一定不是padding的嗎?不是吧。
所以應該再用encoder的valid_lens再過一層mask。取兩者的交集。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM