膨脹卷積


膨脹卷積 Dilated Convolution

也叫空洞卷積 Atrous Convolution

膨脹系數dilation rate \(r=1\)時就是普通卷積,上圖中的膨脹系數\(r=2\)

為什么要引入膨脹卷積?

因為maxpooling進行池化操作后,一些細節和小目標會丟失,在之后的上采樣中無法還原這些信息,造成小目標檢測准確率降低
然而去掉池化層又會減小感受野

膨脹卷積的作用

  • 增大感受野
  • 保持輸入特征的寬高

膨脹卷積的計算

For a convolution kernel with size \(k \times k\), the size of resulting dilated filter is \(k_d \times k_d\), where \(k_d = k + (k − 1) \cdot (r − 1)\)
比如下圖的最左側,\(r=2,k=3\)時,一個卷積核的總覆蓋面積是\(5 \times 5\)
這里要注意實際的感受野其實就是\(k \times k\),因為插了0,實際參與到累加的pixel只有\(k \times k\)

Since dilated convolution introduces zeros in the convolutional kernel, the actual pixels that participate in the computation from the \(k_d \times k_d\) region are just \(k \times k\), with a gap of \(r − 1\) between them.

膨脹卷積的問題

柵格效應Gridding Effect

用上圖解釋一下,如果經過三次膨脹系數\(r=2\),kernel size=\(3\times 3\)的膨脹卷積
第二層的一個pixel用到了第一層的9個pixel,如最左圖所示,接着第三層的一個pixel就用到了第二層的9個pixel,也就相當於第一層的25個pixel,如中間圖所示
顏色代表通過累加,第一層的各個pixel被當前層利用到的次數,次數越高顏色越深
最后就是下圖效果

從上圖可以看出Dilated Convolution的kernel並不連續,也就是並不是所有的像素都用來計算了,因此這里將信息看作checker-board的方式將會損失信息的連續性
而如果第一次采用普通卷積的話就不會丟失底層信息,接着采用不同的膨脹系數,感受野跟上一種方法是一樣大的卻能利用到更多的信息

再與普通卷積對比一下,普通卷積雖然利用了連續的區域,但相比之下感受野就小了很多

Hybrid Dilated Convolution(HDC)

混合卷積論文鏈接

非零元素最大距離

當使用到多個膨脹卷積時,需要設計各卷積核的膨脹系數使其剛好能覆蓋底層特征層

The goal of HDC is to let the final size of the RF of a series of convolutional operations fully covers a square region without any holes or missing edges.

論文中提到了maximum distance between two nonzero values,通過這個來限制膨脹系數的大小

下圖分別為\(r=\)[1,2,5]和[1,2,9]的情況

鋸齒結構

dilated rate設計成了鋸齒狀結構,例如[1, 2, 5, 1, 2, 5]這樣的循環結構

公約數不能大於1

疊加的膨脹卷積的膨脹率dilated rate不能有大於1的公約數(比如[2, 4, 8]),不然會產生柵格效應

畫圖代碼

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


def dilated_conv_one_pixel(center: (int, int),
                           feature_map: np.ndarray,
                           k: int = 3,
                           r: int = 1,
                           v: int = 1):
    """
    膨脹卷積核中心在指定坐標center處時,統計哪些像素被利用到,
    並在利用到的像素位置處加上增量v
    Args:
        center: 膨脹卷積核中心的坐標
        feature_map: 記錄每個像素使用次數的特征圖
        k: 膨脹卷積核的kernel大小
        r: 膨脹卷積的dilation rate
        v: 使用次數增量
    """
    assert divmod(3, 2)[1] == 1

    # left-top: (x, y)
    left_top = (center[0] - ((k - 1) // 2) * r, center[1] - ((k - 1) // 2) * r)
    for i in range(k):
        for j in range(k):
            feature_map[left_top[1] + i * r][left_top[0] + j * r] += v


def dilated_conv_all_map(dilated_map: np.ndarray,
                         k: int = 3,
                         r: int = 1):
    """
    根據輸出特征矩陣中哪些像素被使用以及使用次數,
    配合膨脹卷積k和r計算輸入特征矩陣哪些像素被使用以及使用次數
    Args:
        dilated_map: 記錄輸出特征矩陣中每個像素被使用次數的特征圖
        k: 膨脹卷積核的kernel大小
        r: 膨脹卷積的dilation rate
    """
    new_map = np.zeros_like(dilated_map)
    for i in range(dilated_map.shape[0]):
        for j in range(dilated_map.shape[1]):
            if dilated_map[i][j] > 0:
                dilated_conv_one_pixel((j, i), new_map, k=k, r=r, v=dilated_map[i][j])

    return new_map


def plot_map(matrix: np.ndarray):
    plt.figure()

    c_list = ['white', 'blue', 'red']
    new_cmp = LinearSegmentedColormap.from_list('chaos', c_list)
    plt.imshow(matrix, cmap=new_cmp)

    ax = plt.gca()
    ax.set_xticks(np.arange(-0.5, matrix.shape[1], 1), minor=True)
    ax.set_yticks(np.arange(-0.5, matrix.shape[0], 1), minor=True)

    # 顯示color bar
    plt.colorbar()

    # 在圖中標注數量
    thresh = 5
    for x in range(matrix.shape[1]):
        for y in range(matrix.shape[0]):
            # 注意這里的matrix[y, x]不是matrix[x, y]
            info = int(matrix[y, x])
            ax.text(x, y, info,
                    verticalalignment='center',
                    horizontalalignment='center',
                    color="white" if info > thresh else "black")
    ax.grid(which='minor', color='black', linestyle='-', linewidth=1.5)
    plt.show()
    plt.close()

def main():
    # bottom to top
    dilated_rates = [1, 2, 3]
    # init feature map
    size = 31
    m = np.zeros(shape=(size, size), dtype=np.int32)
    center = size // 2
    m[center][center] = 1
    # print(m)
    # plot_map(m)

    for index, dilated_r in enumerate(dilated_rates[::-1]):
        new_map = dilated_conv_all_map(m, r=dilated_r)
        m = new_map
    print(m)
    plot_map(m)

if __name__ == '__main__':
    main()


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM