【注意力機制】Attention Augmented Convolutional Networks


注意力機制之Attention Augmented Convolutional Networks

原始鏈接:https://www.yuque.com/lart/papers/aaconv

image.png

核心內容

We propose to augment convolutional operators with this self-attention mechanism by concatenating convolutional feature maps with a set of feature maps produced via self-attention.

主要工作

首先了解卷積操作本身兩點特性:

盡管這些屬性被證明了是設計在圖像上操作的模型時至關重要的歸納偏置(inductive biase). 但是卷積的局部性質(the local nature of the convolutional kernel)阻礙了其捕獲全局的上下文信息(global context), 而這些信息對於圖像識別是很必要的. 這是卷積的重要的弱點. (convolution operator is limited by its locality and lack of understandingof global contexts)

而在捕獲長距離交互關系(long range interaction)上, 最近的Self-attention表現的很不錯(has emerged as a recent advance). 自注意力背后的關鍵思想是生成從隱藏單元計算的值的加權平均值. 不同於卷積操作或者池化操作, 這些權重是動態的根據輸入特征, 通過隱藏單元之間的相似性函數產生的(produced dynamically via a similarity function between hidden units). 因此輸入信號之間的交互依賴於信號本身, 而不是像在卷積中, 被預先由他們的相對位置而決定.

所以本文嘗試將自注意力計算應用到卷積操作中, 來實現長距離交互. 在判別性視覺任務(discriminative visual tasks)中, 考慮使用自注意力替換普通的卷積. 引入a novel two-dimensional relative self-attention mechanism, 其在注入(being infused with)相對位置信息的同時可以保持translation equivariance, 使其非常適合圖像.

在取代卷積作為獨立計算單元方面被證明是有競爭力的. 但是需要注意的是, 在控制實驗中發現, 將自注意力和卷積組合起來的情況可以獲得最好的結果. 因此並沒有完全拋棄卷積, 而是提出使用self-attention mechanism來增強卷積(augment convolutions), 即將強調局部性的卷積特征圖和基於self-attention產生的能夠建模更長距離依賴(capable of modeling longer range dependencies)的特征圖拼接來獲得最終結果.

在多個實驗中, 注意力增強卷積都實現了一致的提升, 另外對於完全的自注意模型(不用卷積那部分), 這可以看作是注意力增強模型的一種特殊情況, 在ImageNet上僅比它們的完全卷積結構略差, 這表明自注意機制是一種用於圖像分類的強大獨立的計算原語(a powerful standalone computational primitive).

關於primitive這個概念, 找到了一段解釋: 大意是指整個系統中最基本的概念.
https://stackoverflow.com/a/8022435
For me, it means something that cannot be decomposed (people use also the atomic word sometimes in that sense, but atomic is often also used for explanation on concurrency or parallelism with a different meaning).​
For instance, on Unix (or Linux) the system calls, as seen by the application are primitive or atomic, they either happen or not (sometimes, they got interrupted and give an EINTR or ERESTART error).
And inside an interpreter, or even in the formal specification, of a language, the primitive are those operations which you cannot define, and which the interpreter deals with specially. Very often, cons is a primitive operation for Lisp dialects.

這里提到了其他的一些visual tasks中的注意力的工作:

相對於現有的方法, 這里要提出的結構不依賴於對應的(counterparts)完全卷積模型的預訓練, 而是整個網絡都使用了self-attention mechanism. 另外multi-head attention的使用使得模型同時關注空間子空間和特征子空間. (多頭注意力就是將特征划沿着通道划分為不同的組, 不同組內進行單獨的變換, 可以獲得更加多樣化的特征表達)

另外, 為了增強圖像上的自注意力的表達能力, 這里擴展[Selfattention with relative position representations,  Music transformer]中的相對自注意力到二維形式, 這使得可以以有原則(in a principled way)地模擬平移等變性(translation equivariance).

這樣的結構可以直接產生額外的特征圖, 而不是通過加法(可能是乘法)[Non-local neural networks,  Self-attention generative adversarial networks]或門控[Squeeze-and-excitation networks, Gather-excite: Exploiting feature context in convolutional neural networks, Bam: bottleneck attention module, Cbam: Convolutional block attention module]重新校准卷積特征. 這一特性允許靈活地調整注意力通道的比例, 考慮從完全卷積到完全注意模型的一系列架構(a spectrum of architectures, ranging from fully convolutional to fully attentional models).

主要結構

image.png

  • H, W, Fin: 輸入特征圖的height, weight, 通道數
  • Nh, dv, dk:heads的數量, values的深度(也就是特征圖通道數), queries和keys的深度(這幾個參數都是MHA, multi-head attention的一些參數), 這里有要求, dv和dk必須可以被Nh整除, 這里使用dhv和dhk來作為每個head中值的深度和查詢/鍵的深度

圖像數據多頭注意力的計算

image.png

單頭的計算形式

image.png

多頭是由單頭拼接而成

  1. in_tensor\((H,W,F_{in})\) =(flatten)=> X\((HW,F_{in})\)(We omit the batch dimension for simplicity.)
  2. 按照transformer結構結算多頭注意力
    1. 對於head h對應的自注意力結果為式子1所示, 這里的\(W_q\)/\(W_k\)/\(W_v\)分別形狀為\((F_{in}, d^h_q)/(F_{in}, d^h_k)/(F_{in}, d^h_v)\), 分別用於映射輸入X到查詢\(Q=XW_q\) 、鍵\(K=XW_k\) 和值\(V=XW_v\) , 分別的形狀為\((HW, d^h_q)/(HW, d^h_k)/(HW, d^h_v)\)
    2. 所有head的輸出拼接到一起, 然后按照式子2進行處理, 這里的\(W^O \in \mathbb{R}^{d_v \times d_v}\)(可以知道, 這里的\(N_h\)\(O\)的拼接, 實際上深度為\(d_v\), 也就是\(d_v=N_h \times d^h_v\)), 這里MHA計算后會調整形狀為\((H, W, d_v)\)來匹配原始的空間維度
    3. multi-head attention
      1. 計算復雜度:\(O((HW)^2d_k)\)(這里只需要考慮大頭\((XW_q)(XW_k)^T\)的計算)
      2. 空間復雜度:\(O((HW)^2N_h)\)(這里包含了Nh個頭的結果)

二維位置嵌入Two-dimensional Positional Embeddings

這里的"二維"實際上是相對於原始針對語言的一維信息的結構而言, 這里輸入的是二維圖像數據.

由於沒有顯式的位置信息的利用, 所以自注意力滿足交換律:\(MHA(\pi(X))=\pi(MHA(X))\), 這里的\(\pi\)表示對於像素位置的任意置換. 這反映出來self-attention具有 permutation equivariant. 這樣的性質使得對於模擬高度結構化的數據(例如圖像)而言, 不是很有效.

多個使用顯式的空間信息來增強激活圖的位置編碼已經被提出來處理相關的問題:

  1. Image Transformer extends the sinusoidal waves first introduced in the original Transformer to 2 dimensional inputs.
  2. CoordConv concatenates positional channels to an activation map.

在文章的實驗中發現, 在圖像分類和目標檢測上, 這些編碼方法並不好用, 作者們將其歸因於雖然這些策略可以打破置換等變性, 但是卻不能保證圖像任務需要的平移等變性(permutation equivariant(置換等變性), translation equivariance(平移等變性)). 為此, 這里擴展了現有的相對位置編碼[Self attention with relative position representations]到二維上, 並且基於Music Transformer提出一個內存有效的實現.

相對位置嵌入Relative positional embeddings

Introduced in [Self attention with relative position representations] for the purpose of language modeling, relative self-attention augments self-attention with relative position encodings and enables translation equivariance while preventing permutation equivariance.

這里通過獨立添加相對的寬和相對的高的信息, 來實現二維相對自注意力.
對於像素\(i=(i_x, i_y)\)關於像素\(j=(j_x, j_y)\)的attention logit計算方式如下(The attention logit for how much pixel i attends to pixel j is computed as):

image.png

  • \(q_i\)表示 位置為\(i\) 的query vector, 也就是Q中的一個長為\(d^h_k\)的矢量元素.
  • \(k_j\)表示 位置為\(j\) 的key vector, 也就是K中的一個長為\(d^h_k\)的矢量元素.
  • \(r^W_{j_x-i_x}\)\(r^H_{j_y-i_y}\)表示對於相對寬度\(j_x-i_x\)和相對高度\(j_y-i_y\)學習到的嵌入表示, 各自均為dhk長度的矢量.
  • \(r\)對應的相對位置參數矩陣\(r^W\)\(r^H\)分別是\((2W-1, d^h_k)\)\((2H-1, d^h_k)\)大小的.

單個頭h的輸出變成了:

image.png

這里的兩個\(S\)都是\(HW \times HW\)的矩陣, 表示沿着寬高維度的相對位置logits

  • image.png
  • image.png

因為考慮相對寬高信息, 所以滿足\(S^{rel}_W[i, j]=S^{rel}_W[i, j+W]\),\(S^{rel}_H[i, j]=S^{rel}_H[i, j+H]\). 這樣就不需要為所有的(i, j)對計算logits了, 這里可以按照這樣來理解(這是我自己的理解): 對於二維矩陣, 按照沿着行為W方向(橫向), 也即是x方向, 沿着列為H方向(縱向)即y向, 對於任意一點\(j\)和固定的點\(i\):

  • SW中有\((j_x-i_x)\%W=[(j+nW)_x-i_x]\%W\), 即按照行主序向后移動個位置, 仍位於同一列;
  • SH中有\((j_y-i_y)\%H=[(j+nH)_y-i_x]\%H\), 即按照列主序向后移動\(nH\)個位置, 依然在同一行.

這里的相對注意力的形式實際上不同於原始參考論文Self attention with relative position representations中具有內存占用為\(O((HW)^2d^h_k)\)(相對嵌入\(r_{ij} \in \mathbb{R}^{HW \times HW \times d^h_k}\))的設計, 而是基於MUSIC TRANSFORMER中提出的memory efficient relative masked attention algorithm的一種2D擴展, 擴展為了unmasked relative self-attention over 2 dimensional inputs上, 從而存儲消耗變成了\(O(HWd^h_k)\)(相對位置嵌入\(r_{ij}\)被拆分成兩個部分, 即\(r^H \in \mathbb{R}^{(2H-1) \times d^h_k}, r^W \in \mathbb{R}^{(2W-1 )\times d^h_k}\), 並且跨頭不跨層的形式進行共享). 對於每層, 實際上只需要添加額外的\((2(H + W) − 2)d^h_k\)個參數來建模沿着高和寬的相對距離即可.

Attention Augmented Convolution

文章提出的使用注意力增強的卷積主要的優勢:

  1. use an attention mechanism that can attend jointly to spatial and feature subspaces (each head corresponding to a feature subspace)
  2. introduce additional feature maps rather than refining them

AAConv的主要過程:

image.png

Similarly to the convolution, the proposed attention augmented convolution

  1. is equivariant to translation
  2. can readily operate on inputs of different spatial dimensions

接下來對標一般的卷積\((F_{out}, F_{in}, k, k)\)分析了AAConv的參數量:

  • 設置\(v=\frac{d_v}{F_{out}}\)作為MHA部分的總輸出通道數與總的AAConv輸出通道數的比值;
  • 設置\(\kappa = \frac{d_k}{F_{out}}\)作為MHA中Key的深度與總的AAConv輸出通道數的比值.
  • 使用\(1 \times 1\)卷積來線性變換得到Q\K\V, 所以有參數量\((d_v+d_k+d_q)F_{in} = (2d_k+d_v)F_{in}=(v+2\kappa)F_{out}F_{in}\)
  • 使用一個額外的\(1\times1\)卷積用於混合多個頭的貢獻(mix the contribution of different heads), 這部分參數量為\(d_vd_v=(vF_{out})^2\);
  • 除了注意力部分, 還有一部分標准卷積, 即前面式子中的Conv, 其參數量為:\(k^2(F_{out} - d_v)F_{in} = k^2(1 - v)F_{out}F_{in}\);
  • 所以, 忽略了相對位置嵌入和卷積偏置之后, 整體的結構的參數量約為:\(F_{in}F_{out}(2\kappa+v+v^2\frac{F_{out}}{F_{in}}+k^2-k^2v)=F_{in}F_{out}(2\kappa+v(1-k^2)+k^2+v^2\frac{F_{out}}{F_{in}})\)
  • 整體相對於卷積的參數的變化量為\(\Delta_{params}\sim F_{in}F_{out}(2\kappa+v(1-k^2)+v^2\frac{F_{out}}{F_{in}})\), 所以替換3x3卷積時, 會輕微減少參數量, 而替換1x1卷積時, 則會帶來輕微的增加.

Attention Augmented Convolutional Architectures

  • 所有實驗中, AAConv后都會跟着BN來放縮卷積層和注意力層特征圖的共享.
  • 每個殘差塊使用一次AAConv.
  • 由於QK的結果具有較大的內存占用, 所以是按照從深到淺的順序使用, 直到達到內存上限.
  • To reduce the memory footprint of augmented networks, we typically resort to a smaller batch size and sometimes additionally downsample the inputs to self-attention in the layers with the largest spatial dimensions where it is applied(這里指的應該是在注意力計算前后分別下采樣和上采樣). Downsampling is performed by applying 3x3 average pooling with stride 2 while the following upsampling (requiredfor the concatenation) is obtained via bilinear interpolation.

實驗結果

位置編碼

image.png

image.png

  • the position-unaware version of self-attention (referred to as None),
  • a two-dimensional implementation of the sinusoidal positional waves (referred to as 2d Sine) as used in [32],
  • CoordConv [29] for which we concatenate (x, y, r) coordinate channels to the inputs of the attention function,
  • our proposed two-dimensional relative position encodings (referred to as Relative).

未來的探索

  • Several open questions from this work remain. In future work, we will focus on the fully attentional regime and explore how different attention mechanisms trade off computational efficiency versus representational power. For instance, identifying a local attention mechanism may result in an efficient and scalable computational mechanism that could prevent the need for downsampling with average pooling [Stand-aloneself-attention in vision models].
  • Additionally, it is plausible that architectural design choices that are well suited when exclusively relying on convolutions are suboptimal when using self-attention mechanisms. As such, it would be interesting to see if using Attention Augmentation as a primitive in automated architecture search procedures proves useful to find even better models than those previously found in image classification [55], object detection [12], image segmentation [6] and other domains [5, 1, 35, 8].
  • Finally, one can ask to which degree fully attentional models can replace convolutional networks for visual tasks.

代碼示例

參照作者論文中的tensorflow實現, 我使用pytorch改了下.

import torch
from einops import rearrange
from torch import nn

def rel_to_abs(x):
    """
    Converts tensor from relative to aboslute indexing.
    Details can be found at: https://www.yuque.com/lart/ugkv9f/oazsec

    :param x: B Nh L 2L-1
    :return: B Nh L L
    """
    B, Nh, L, _ = x.shape

    # Pad to shift from relative to absolute indexing.
    col_pad = torch.zeros(B, Nh, L, 1)
    x = torch.cat([x, col_pad], dim=3)

    flat_x = x.reshape(B, Nh, L * 2 * L)

    flat_pad = torch.zeros(B, Nh, L - 1)
    flat_x = torch.cat([flat_x, flat_pad], dim=2)

    # Reshape and slice out the padded elements.
    final_x = flat_x.reshape(B, Nh, L + 1, 2 * L - 1)
    final_x = final_x[:, :, :L, L - 1:]
    return final_x

def relative_logits_1d(x, rel_k):
    """
    Compute relative logits along one dimenion.

    :param x: B Nh Hd L
    :param rel_k: 2L-1 Hd
    """
    rel_logits = torch.einsum("bndl, rd -> bnlr", x, rel_k)
    rel_logits = rel_to_abs(rel_logits)  # B Nh L 2L-1 -> B Nh L L
    return rel_logits

class RelativePosEmbedding(nn.Module):
    """
    Compute relative_logits.

    For ease, we 1) transpose height and width, 2) repeat the above steps and 3) transpose to eventually
    put the logits in their right positions.
    """

    def __init__(self, h, w, dim):
        super(RelativePosEmbedding, self).__init__()
        self.h = h
        self.w = w
        self.rel_emb_w = torch.randn(2 * w - 1, dim)
        nn.init.normal_(self.rel_emb_w, dim ** -0.5)
        self.rel_emb_h = torch.randn(2 * h - 1, dim)
        nn.init.normal_(self.rel_emb_h, dim ** -0.5)

    def forward(self, x):
        """
        :param x: B Nh Hd HW
        :return: B Nh HW HW
        """
        Nh = x.shape[1]
        # Relative logits in width dimension first.
        rel_logits_w = relative_logits_1d(
            rearrange(x, "b nh hd (h w) -> b (nh h) hd w", h=self.h, w=self.w), self.rel_emb_w
        )
        rel_logits_w = rearrange(rel_logits_w, "b (nh h) w0 w1 -> b nh h () w0 w1", nh=Nh)
        # Relative logits in height dimension next.
        rel_logits_h = relative_logits_1d(
            rearrange(x, "b nh hd (h w) -> b (nh w) hd h", h=self.h, w=self.w), self.rel_emb_h
        )
        rel_logits_h = rearrange(rel_logits_h, "b (nh w) h0 h1 -> b nh h0 h1 w ()", nh=Nh)
        return rearrange(rel_logits_h + rel_logits_w, "b nh h0 h1 w0 w1 -> b nh (h0 w0) (h1 w1)")

class AbsolutePosEmbedding(nn.Module):
    """
    Given query q of shape [batch heads tokens dim] we multiply
    q by all the flattened absolute differences between tokens.
    Learned embedding representations are shared across heads
    """

    def __init__(self, h, w, dim):
        super().__init__()
        scale = dim ** -0.5
        self.abs_pos_emb = nn.Parameter(torch.randn(h * w, dim) * scale)
        nn.init.normal_(self.abs_pos_emb, scale)

    def forward(self, x):
        """
        :param x: B Nh Hd HW
        :return: B Nh HW HW
        """
        return torch.einsum("bndx, yd -> bhxy", x, self.abs_pos_emb)

class SelfAttention2D(nn.Module):
    def __init__(self, in_dim, key_dim, value_dim, nh, hw, pos_mode="relative"):
        super(SelfAttention2D, self).__init__()
        self.dkh = key_dim // nh
        self.dvh = value_dim // nh
        self.nh = nh
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.kqv_proj = nn.Conv2d(in_dim, 2 * key_dim + value_dim, 1)
        self.out_proj = nn.Conv2d(value_dim, value_dim, 1)
        if pos_mode == "relative":
            self.position_embedding = RelativePosEmbedding(h=hw[0], w=hw[1], dim=self.dkh)
        elif pos_mode == "absolute":
            self.position_embedding = AbsolutePosEmbedding(h=hw[0], w=hw[1], dim=self.dkh)
        else:
            self.position_embedding = nn.Identity()

    def split_heads_and_flatten(self, _x):
        return rearrange(_x, "b (nh hd) h w -> b nh hd (h w)", nh=self.nh)

    def forward(self, x):
        """
        :param x: B C H W
        """

        # Compute q, k, v
        k, q, v = self.kqv_proj(x).split([self.key_dim, self.key_dim, self.value_dim], dim=1)
        q = q * self.dkh ** -0.5  # scaled dot-product

        # After splitting, shape is [B, Nh, dkh or dvh, HW]
        q, k, v = map(self.split_heads_and_flatten, (q, k, v))

        # [B, Nh, HW, HW]
        logits = torch.einsum("bndx, bndy -> bnxy", q, k)
        logits += self.position_embedding(q)
        weights = logits.softmax(-1)
        attn_out = torch.einsum("bnxy, bndy -> bndx", weights, v)
        attn_out = rearrange(attn_out, "b nd hd (h w) -> b (nd hd) h w", h=x.shape[2], w=x.shape[3])

        # Project heads
        attn_out = self.out_proj(attn_out)
        return attn_out

class AugmentedConv2d(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, key_dim, value_dim, num_heads, hw, pos_mode):
        super(AugmentedConv2d, self).__init__()
        self.std_conv = nn.Conv2d(in_dim, out_dim - value_dim, kernel_size, padding=kernel_size // 2)
        self.attention = SelfAttention2D(
            in_dim, key_dim=key_dim, value_dim=value_dim, nh=num_heads, hw=hw, pos_mode=pos_mode
        )

    def forward(self, x):
        conv_out = self.std_conv(x)
        attn_out = self.attention(x)
        return torch.cat([conv_out, attn_out], dim=1)

if __name__ == "__main__":
    m = AugmentedConv2d(
        in_dim=4, out_dim=64, kernel_size=3, key_dim=32, value_dim=48, num_heads=2, hw=(10, 10), pos_mode="relative"
    )
    print(m(torch.randn(4, 4, 10, 10)).shape)

一些疑惑

  • permutation equivariance(置換等變性), translation equivariance(平移等變性)二者的差異是什么?

補充知識

對於self-attention包含三個輸入, query Q/key K/value V, 三者具體表示的含義是什么呢? 以下內容摘自https://www.cnblogs.com/rosyYY/p/10115424.html:

  1. Q、K、V中包含的都是原始數據的嵌入表示
  2. Q為什么叫query?
    1. 是因為每次需要拿一個嵌入表示去"查詢"其和任意的嵌入表示之間的match程度, 也就是attention大小
  3. K和V表示鍵值, 關於這里的解釋, 各處都語焉不詳, 在 從Seq2seq到Attention模型到Self Attention(二) - 量化投資機器學習的文章 - 知乎 https://zhuanlan.zhihu.com/p/47470866 中有處提到:"key、value的起源論文 Key-Value Memory Networks for Directly Reading Documents. 在NLP的領域中, Key, Value通常就是指向同一個文字隱向量(word embedding vector)". 暫且做過多解釋.

相關鏈接


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM