【CV中的Attention機制】Non-Local Network的理解與實現


1. Non-local

Non-Local是王小龍在CVPR2018年提出的一個自注意力模型。Non-Local Neural Network和Non-Local Means非局部均值去燥濾波有點相似的感覺。普通的濾波都是3×3的卷積核,然后在整個圖片上進行移動,處理的是3×3局部的信息。Non-Local Means操作則是結合了一個比較大的搜索范圍,並進行加權。

在Non-Local NN這篇文章中的Local也與以上有一定關系,主要是針對感受野來說的,一般的卷積的感受野都是3×3或5×5的大小,而使用Non-Local可以讓感受野很大,而不是局限於一個局部領域。

與之前介紹的CBAM模塊,SE模塊,BAM模塊,SK模塊類似,Non-Local也是一個易於集成的模塊,針對一個feature map進行信息的refine, 也是一種比較好的attention機制的實現。不過相比前幾種attention模塊,Non-Local中的attention擁有更多地理論支撐,稍微有點晦澀難懂。

Non-local的通用公式表示:

\[y_i=\frac{1}{C(x)}\sum_{\forall j}f(x_i,x_j)g(x_j) \]

  • x是輸入信號,cv中使用的一般是feature map
  • i 代表的是輸出位置,如空間、時間或者時空的索引,他的響應應該對j進行枚舉然后計算得到的
  • f 函數式計算i和j的相似度
  • g 函數計算feature map在j位置的表示
  • 最終的y是通過響應因子C(x) 進行標准化處理以后得到的

理解:與Non local mean相比,就很容易理解,i 代表的是當前位置的響應,j 代表全局響應,通過加權得到一個非局部的響應值。

Non-Local的優點是什么?

  • 提出的non-local operations通過計算任意兩個位置之間的交互直接捕捉遠程依賴,而不用局限於相鄰點,其相當於構造了一個和特征圖譜尺寸一樣大的卷積核, 從而可以維持更多信息。
  • non-local可以作為一個組件,和其它網絡結構結合,經過作者實驗,證明了其可以應用於圖像分類、目標檢測、目標分割、姿態識別等視覺任務中,並且效果有不同程度的提升。
  • Non-local在視頻分類上效果很好,在視頻分類的任務中效果可觀。

2. 細節

論文中給了通用公式,然后分別介紹f函數g函數的實例化表示:

g函數:可以看做一個線性轉化(Linear Embedding)公式如下:

\[g(x_j)=W_gx_j \]

\(W_g​\) 是需要學習的權重矩陣,可以通過空間上的1×1卷積實現(實現起來比較簡單)。


f函數:這是一個用於計算i和j相似度的函數,作者提出了四個具體的函數可以用作f函數。

  • Gaussian function: 具體公式如下:

\[f(x_i,x_j)=e^{x_i^Tx_j} \\ C(x)=\sum_{\forall j}f(x_i,x_j) \]

這里使用的是 \(x_i^Tx_j\) 一個點乘來計算相似度,之所以點積可以衡量相似度,這是通過余弦相似度簡化而來的。

\[\vec a *\vec b = |\vec a||\vec b|cos \theta \]

  • Embedded Gaussian: 具體公式如下:

\[f(x_i,x_j)=e^{\theta(x_i)^T\phi(x_j)} \\ C(x)=\sum_{\forall j}f(x_i,x_j) \]

  • Dot product: 具體公式如下:

\[f(x_i,x_j)=\theta(x_i)^T\phi(x_j) \\ C(x)=|\{i|i is a valid index of x\}| \]

  • Concatenation: 具體公式如下:

\[f(x_i,x_j)=ReLU(w_f^T .[\theta(x_i),\phi(x_j)]) \\ C(x)=|\{i|i is a valid index of x\}| \]


以上四個函數可能看起來感覺讓人讀起來很吃力,下邊進行大概解釋一下上邊符號的意義,結合示意圖(以Embeded Gaussian為例,對原圖進行細節上加工,具體參見代碼,地址為文末鏈接中的non_local_embedded_gaussian.py文件):

  • x代表feature map, \(x_i\) 代表的是當前關注位置的信息; \(x_j\) 代表的是全局信息。

  • θ代表的是 \(\theta (x_i)=W_{\theta}x_i​\) ,實際操作是用一個1×1卷積進行學習的。

  • φ代表的是 \(\phi (x_j)=W_{\phi}x_j\),實際操作是用一個1×1卷積進行學習的。

  • g函數意義同上。

  • C(x)代表的是歸一化操作,在embedding gaussian中使用的是Sigmoid實現的。

然后可以將上圖(實現角度)與下圖(比較抽象)進行結合理解:

具體解釋如下:(ps: 以下解釋帶上了bs,上圖中由於bs不方便畫圖,所以沒有添加bs)

X是一個feature map,形狀為[bs, c, h, w], 經過三個1×1卷積核,將通道縮減為原來一半(c/2)。然后將h,w兩個維度進行flatten,變為h×w,最終形狀為[bs, c/2, h×w]的tensor。對θ對應的tensor進行通道重排,在線性代數中也就是轉置,得到形狀為[bs, h×w, c/2]。然后與φ代表的tensor進行矩陣乘法,得到一個形狀為[bs, h×w,h×w]的矩陣,這個矩陣計算的是相似度(或者理解為attention)。然后經過softmax進行歸一化,然后將該得到的矩陣 \(f_c\) 與g 經過flatten和轉置的結果進行矩陣相乘,得到的形狀為[bs, h*w, c/2]的結果y。然后轉置為[bs, c/2, h×w]的tensor, 然后將h×w維度重新伸展為[h, w],從而得到了形狀為[bs, c/2, h, w]的tensor。然后對這個tensor再使用一個1×1卷積核,將通道擴展為原來的c,這樣得到了[bs, c, h, w]的tensor,與初始X的形狀是一致的。最終一步操作是將X與得到的tensor進行相加(類似resnet中的residual block)。

可能存在的問題

計算量偏大:在高階語義層引入non local layer, 也可以在具體實現的過程中添加pooling層來進一步減少計算量。

3. 代碼

代碼來自官方,修改了一點點以便於理解,推薦將代碼的forward部分與上圖進行對照理解。

import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    """
    調用過程
    NONLocalBlock2D(in_channels=32),
    super(NONLocalBlock2D, self).__init__(in_channels,
            inter_channels=inter_channels,
            dimension=2, sub_sample=sub_sample,
            bn_layer=bn_layer)
    """
    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 dimension=3,
                 sub_sample=True,
                 bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            # 進行壓縮得到channel個數
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels,
                         out_channels=self.inter_channels,
                         kernel_size=1,
                         stride=1,
                         padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels,
                        out_channels=self.in_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0), bn(self.in_channels))
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels,
                             out_channels=self.in_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels,
                             out_channels=self.inter_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
        self.phi = conv_nd(in_channels=self.in_channels,
                           out_channels=self.inter_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c,  h, w)
        :return:
        '''

        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)#[bs, c, w*h]
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        
        f = torch.matmul(theta_x, phi_x)

        print(f.shape)

        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z

4. 實驗結論

  • 文中提出了四個計算相似度的模型,實驗對四個方法都進行了實驗,發現了這四個模型效果相差並不大,於是有一個結論:使用non-local對baseline結果是有提升的,但是不同相似度計算方法之間差距並不大,所以可以采用其中一個做實驗即可,文中用embedding gaussian作為默認的相似度計算方法。

  • 作者做了一系列消融實驗來證明non local NN的有效性:

  1. 使用四個相似度計算模型,發現影響不大,但是都比baseline效果好。

  1. 以ResNet50為例,測試加在不同stage下的結果。可以看出在res2,3,4部分得到的結果相對baseline提升比較大,但是res5就一般了,這有可能是由於第5個stage中的feature map的spatial size比較小,信息比較少,所以提升比較小。

  1. 嘗試添加不同數量的non local block ,結果如下。可以發現,添加越多的non local 模塊,其效果越好,但是與此同時帶來的計算量也會比較大,所以要對速度和精度進行權衡。

  1. Non-local 與3D卷積的對比,發現要比3D卷積計算量小的情況下,准確率有較為可觀的提升。

  1. 作者還將Non-local block應用在目標檢測、實例分割、關鍵點檢測等領域。可以將non-local block作為一個trick添加到目標檢測、實例分割、關鍵點檢測等領域, 可能帶來1-3%的提升。

5. 評價

Non local NN從傳統方法Non local means中獲得靈感,然后接着在神經網絡中應用了這個思想,直接融合了全局的信息,而不僅僅是通過堆疊多個卷積層獲得較為全局的信息。這樣可以為后邊的層帶來更為豐富的語義信息。

論文中也通過消融實驗,完全證明了該模塊在視頻分類,目標檢測,實例分割、關鍵點檢測等領域的有效性,但是其中並沒有給出其帶來的參數量上的變化,或者計算速度的變化。但是可以猜得到,參數量的增加還是有一定的,如果對速度有要求的實驗可能要進行速度和精度上的權衡,不能盲目添加non local block。神經網絡中還有一個常見的操作也是利用的全局信息,那就是Linear層,全連接層將feature map上每一個點的信息都進行了融合,Linear可以看做一種特殊的Non local操作。

之后GCNet等工作對Non-Local Neural Network結構進行改進,能夠大幅降低Non-Local NN的計算量,更具有實用價值。

6. 參考內容

論文:https://arxiv.org/abs/1711.07971

video classification 代碼:https://github.com/facebookresearch/video-nonlocal-net

non local官方實現:https://github.com/pprp/SimpleCVReproduction/tree/master/attention/Non-local/Non-Local_pytorch_0.4.1_to_1.1.0/lib

知乎文章:https://zhuanlan.zhihu.com/p/33345791

博客:https://hellozhaozheng.github.io/z_post/計算機視覺-NonLocal-CVPR2018/


推薦閱讀:

CV中的Attention機制-最簡單最易實現的SE模塊

CV中的Attention機制-Selective-Kernel-Networks-SE進化版

CV中的Attention機制-CBAM模塊

CV中的Attention機制-並行版的CBAM-BAM模塊

CV中的attention機制-語義分割中的scSE模塊


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM