Octave Convolution詳解


前言

Octave Convolution來自於這篇論文Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution這篇論文,該論文也被ICCV2019接收。

Octave表示的是音階的八度,而本篇核心思想是通過對數據低頻信息減半從而達到加速卷積運算的目的,而兩個Octave之間也是聲音頻率減半【2】。

Octave Convolution(后面將以OctConv命名)主要有以下三個貢獻:

  • 將卷積特征圖分成了兩組,一組低頻,一組高頻,低頻特征圖的大小會減半,從而可以有效減少存儲以及計算量,另外,由於特征圖大小減小,卷積核大小不變,感受野就變大了,可以抓取更多的上下文信息;
  • OctConv是一種即插即用的卷積塊,可以直接替換傳統的conv(也替換分組卷積以及深度可分離卷積等),減小內存和計算量;
  • 當然,作者做了大量實驗,使用OctConv可以得到更高性能,甚至可以媲美best Auto ML。

總的來說,OctConv是占用內存小,速度快,性能高,即插即用的conv塊。

OctConv的特征表示

自然圖像可被分解為低頻分量以及高頻分量,如下所示:

而卷積層的特征圖也可以分為低頻和高頻分量,如下圖(b)所示,OctConv卷積的做法是將低頻分量的空間分辨率減半(如下圖c所示),然后分兩組進行conv,兩組頻率之間會通過上采樣和下采樣進行信息交互(見下圖d),最后再合成原始特征圖大小。

作者認為低頻分量在一些特征圖中是富集的,可以被壓縮的,所以對低頻分量進行了壓縮,壓縮的方式沒有采用stride conv,而是使用了average pooling,因為stride conv會導致不對齊的行為。

OctConv的詳細過程

如上圖所示,OctConv的輸入有兩部分,一部分是高頻\(X^H\),另一部分是低頻\(X^L\),觀察到\(X^L\)的大小是\(X^H\)的二分之一,這里通過兩個參數\(\alpha_{in}\)\(\alpha_{out}\)來控制低高頻的輸入通道和輸出通道,一開始,輸入只有一個\(X\),這時候的\(\alpha_{in}\)為0,然后通過兩個卷積層(\(f\left(X^{H} ; W^{H \rightarrow H}\right)\)\(f\left(p o o l\left(X^{H}, 2\right) ; W^{H \rightarrow L}\right)\))得到高頻分量和低頻分量,中間的OctConv就是有兩個輸入了和兩個輸出,最后需要從兩個輸入恢復出一個輸出,此時\(\alpha_{out}\)為0,通過\(f\left(X^{H} ; W^{H \rightarrow H}\right)\)\(upsample\left(f\left(X^{L} ; W^{L \rightarrow H}\right), 2\right)\)兩個操作得到單獨輸出。

現在來討論上面的四根線上的操作各代表什么。

\(f\left(X^{H} ; W^{H \rightarrow H}\right)\)是高頻信息到高頻信息,通過一個卷積層即可。

\(f\left(p o o l\left(X^{H}, 2\right) ; W^{H \rightarrow L}\right)\)是將高頻信息匯合到低頻信息中,先通過一個平均池化,然后通過一個卷積層。

\(upsample\left(f\left(X^{L} ; W^{L \rightarrow H}\right), 2\right)\)是將低頻信息匯合到高頻信息,先通過一個卷積層,然后通過平均池化層。

\(f\left(X^{L} ; W^{L \rightarrow L}\right)\)是將低頻信息到低頻信息,通過一個卷積層。

現在看看卷積核參數分配的問題,如下圖所示:

上面的四個操作對應上圖的四個部分,可以看到總的參數依然是\(c_{i n} \times c_{o u t} \times k \times k\),但由於低頻分量的尺寸減半,所需要的存儲空間變小,以及計算量縮減,達到加速卷積的過程。

Pytorch代碼

下面的代碼來自於OctaveConv_pytorch ,代碼可讀性很高,如果理解了上述過程,看起來會很容易。

第一層OctConv卷積,將特征圖x分為高頻和低頻:

class FirstOctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels,kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(FirstOctaveConv, self).__init__()
        self.stride = stride
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.h2l = torch.nn.Conv2d(in_channels, int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(in_channels, out_channels - int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)

    def forward(self, x):
        if self.stride ==2:
            x = self.h2g_pool(x)

        X_h2l = self.h2g_pool(x)
        X_h = x
        X_h = self.h2h(X_h)
        X_l = self.h2l(X_h2l)

        return X_h, X_l

中間層的OctConv,低高頻輸入,低高頻輸出:

class OctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(OctaveConv, self).__init__()
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
        self.stride = stride
        self.l2l = torch.nn.Conv2d(int(alpha * in_channels), int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.l2h = torch.nn.Conv2d(int(alpha * in_channels), out_channels - int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2l = torch.nn.Conv2d(in_channels - int(alpha * in_channels), int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(in_channels - int(alpha * in_channels),
                                   out_channels - int(alpha * out_channels),
                                   kernel_size, 1, padding, dilation, groups, bias)

    def forward(self, x):
        X_h, X_l = x

        if self.stride ==2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)

        X_h2l = self.h2g_pool(X_h)

        X_h2h = self.h2h(X_h)
        X_l2h = self.l2h(X_l)

        X_l2l = self.l2l(X_l)
        X_h2l = self.h2l(X_h2l)
        
        X_l2h = self.upsample(X_l2h)
        X_h = X_l2h + X_h2h
        X_l = X_h2l + X_l2l

        return X_h, X_l

最后一層的OctConv,將低高頻匯合稱輸出。

class LastOctaveConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
                 groups=1, bias=False):
        super(LastOctaveConv, self).__init__()
        self.stride = stride
        kernel_size = kernel_size[0]
        self.h2g_pool = nn.AvgPool2d(kernel_size=(2,2), stride=2)

        self.l2h = torch.nn.Conv2d(int(alpha * in_channels), out_channels,
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.h2h = torch.nn.Conv2d(in_channels - int(alpha * in_channels),
                                   out_channels,
                                   kernel_size, 1, padding, dilation, groups, bias)
        self.upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        X_h, X_l = x

        if self.stride ==2:
            X_h, X_l = self.h2g_pool(X_h), self.h2g_pool(X_l)

        X_l2h = self.l2h(X_l)
        X_h2h = self.h2h(X_h)
        X_l2h = self.upsample(X_l2h)
        
        X_h = X_h2h + X_l2h

        return X_h

參考

【1】 Octave Convolution論文

【2】Pytorch代碼

【3】Octave Convolution博客


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM