【CV中的Attention機制】Selective Kernel Networks(SE進化版)


1. SKNet

SKNet是SENet的加強版,結合了SE opetator, Merge-and-Run Mappings以及attention on inception block的產物。其最終提出的也是與SE類似的一個模塊,名為SK, 可以自適應調節自身的感受野。據作者說,該模塊在超分辨率任務上有很大提升,並且論文中的實驗也證實了在分類任務上有很好的表現。

這篇博客重畫了SK模塊示意圖,詳見下圖,下圖中上邊的部分是重畫的,下邊的是論文中的圖,雖然比較簡潔,但是比較難理解。上邊重畫的部分分為了三個部分,而原來的模塊分成了兩個模塊。

接下來對照着圖先理一遍思路,然后再直接上pytorch版本的代碼。

論文中說這個模塊可以更好地實現多個分辨率,調節感受野,個人理解就是從不同的分支造成的。下邊講解對照上圖進行:

原始feature map X 經過kernel size分別為3×3,5×5....以此類推的卷積進行卷積后得到U1,U2,U3三個,然后相加得到了U,相當於融合了多個感受野的信息。然后得到的U是C×H×W的(C代表channel,H代表height, W代表width)feature map,然后將H和W維度求平均值,具體做法是使用torch.mean完成,最終得到了關於channel的信息是一個C×1×1的一維向量,代表的是各個通道的信息的重要程度。

之后再用了一個線性變換,將原來的C維映射成z維度的信息,進行信息抽取,然后分別使用了三個線性變換,從z維度變為原來的c維度,這樣完成了針對channel維度的信息提取,然后使用Softmax進行歸一化,這時候每個channel對應一個分數,代表其channel的重要程度,這相當於一個打分mask。將這三個分別得到的mask分別乘以對應的U1,U2,U3,得到A1,A2,A3, 然后相加三個模塊,進行信息融合,得到最終模塊A, 模塊A相比於最初的X經過了信息的提純,具有了多尺度的信息。

經過以上分析,就能理解了作者的SK模塊的構成了:

  • 從C線性變換為Z維,再到C維度,這個部分與SE operator比較像
  • 多分支的操作借鑒自:inception
  • 整個流程類似merge-and-run mapping

這就是merge-and-run mapping中提出的三個基礎模塊,與本文sk雖然沒有直接聯系,但是都是屬於先進行分支,然后在合並,也類似於inception中的圖。

2. pytorch代碼

import torch.nn as nn
import torch

class SKConv(nn.Module):
    def __init__(self, features, WH, M, G, r, stride=1, L=32):
        """ Constructor
        Args:
            features: input channel dimensionality.
            WH: input spatial dimensionality, used for GAP kernel size.
            M: the number of branchs.
            G: num of convolution groups.
            r: the radio for compute d, the length of z.
            stride: stride, default 1.
            L: the minimum dim of the vector z in paper, default 32.
        """
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M
        self.features = features
        self.convs = nn.ModuleList([])
        for i in range(M):
            self.convs.append(
                nn.Sequential(
                    nn.Conv2d(features,
                              features,
                              kernel_size=3 + i * 2,
                              stride=stride,
                              padding=1 + i,
                              groups=G), nn.BatchNorm2d(features),
                    nn.ReLU(inplace=False)))
        # self.gap = nn.AvgPool2d(int(WH/stride))
        print("D:", d)
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(nn.Linear(d, features))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        for i, conv in enumerate(self.convs):
            fea = conv(x).unsqueeze_(dim=1)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas, fea], dim=1)
        fea_U = torch.sum(feas, dim=1)
        # fea_s = self.gap(fea_U).squeeze_()
        fea_s = fea_U.mean(-1).mean(-1)
        fea_z = self.fc(fea_s)
        for i, fc in enumerate(self.fcs):
            print(i, fea_z.shape)
            vector = fc(fea_z).unsqueeze_(dim=1)
            print(i, vector.shape)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector],
                                              dim=1)
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v

if __name__ == "__main__":
    t = torch.ones((32, 256, 24,24))
    sk = SKConv(256,WH=1,M=2,G=1,r=2)
    out = sk(t)
    print(out.shape)

3. 資源

sknet論文地址:https://arxiv.org/pdf/1903.06586.pdf

作者知乎講解:https://zhuanlan.zhihu.com/p/59690223

代碼源自:https://github.com/implus/SKNet


畫圖、碼字不易,求個關注


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM