[卷積核]空洞卷積(轉)


轉自: https://www.cnblogs.com/hellcat/p/9687624.html

 


一、空洞卷積的提出

空洞卷積(atrous convolutions)又名擴張卷積(dilated convolutions),向卷積層引入了一個稱為 “擴張率(dilation rate)”的新參數,該參數定義了卷積核處理數據時各值的間距。

該結構的目的是在不用pooling(pooling層會導致信息損失)且計算量相當的情況下,提供更大的感受野。 順便一提,卷積結構的主要問題如下:

池化層不可學

內部數據結構丟失;空間層級化信息丟失。

小物體信息無法重建 (假設有四個pooling layer 則 任何小於 2^4 = 16 pixel 的物體信息將理論上無法重建。)

而空洞卷積就有內部數據結構的保留和避免使用 down-sampling 這樣的特性,優點明顯。

二、空洞卷積原理

如下如,卷積核沒有紅點標記位置為0,紅點標記位置同正常卷積核。

假設原始特征為feat0,首先使用擴張率為1的空洞卷積生成feat1,feat1上一點相對feat0感受野為3*3(如圖a);

然后使用擴張率為2的空洞卷積處理feat1生成feat2(如圖b),使第一次空洞卷積的卷積核大小等於第二次空洞卷積的一個像素點的感受野,圖b即feat1上一個點綜合了圖a即feat0上3*3區域的信息,則生成的feat2感受野為7*7,即整個圖b深色區域;

第三次處理同上,第二次空洞卷積的整個卷積核大小等於第三次空洞卷積的一個像素點的感受野,圖c即feat2上每個點綜合了feat0上7*7的信息(感受野),則采用擴張率為3的空洞卷積,生成的feat3每一個點感受野為15*15。

相比較之下,使用stride為1的普通3*3卷積,三層之后感受野僅僅為(kernel-1)*layer+1=7

三、空洞卷積問題

感受野跳躍

我們對同一張圖連續三次使用擴張率為1的空洞卷積,觀察整張圖的中心點的感受野(如下圖)

很明顯,感受野不連續(我們上一小結的例子就沒這個問題,所以空洞卷積依賴網絡設計)。

小尺度物體檢測

類似第一個問題,仍然需要調整擴張率的組合來解決這個問題。

四、網絡設計研究

第一個特性是,疊加卷積的 dilation rate 不能有大於1的公約數。比如 [2, 4, 6] 則不是一個好的三層卷積,依然會出現 gridding effect。

第二個特性是,我們將 dilation rate 設計成 鋸齒狀結構,例如 [1, 2, 5, 1, 2, 5] 循環結構。

第三個特性是,我們需要滿足一下這個式子: M_i=\max[M_{i+1}-2r_i,M_{i+1}-2(M_{i+1}-r_i),r_i]

其中 r_i 是 i 層的 dilation rate 而 M_i 是指在 i 層的最大dilation rate,那么假設總共有n層的話,默認 M_n=r_n 。假設我們應用於 kernel 為 k x k 的話,我們的目標則是 M_2 \leq k ,這樣我們至少可以用 dilation rate 1 即 standard convolution 的方式來覆蓋掉所有洞。

一個簡單的例子: dilation rate [1, 2, 5] with 3 x 3 kernel (可行的方案)

而這樣的鋸齒狀本身的性質就比較好的來同時滿足小物體大物體的分割要求(小 dilation rate 來關心近距離信息,大 dilation rate 來關心遠距離信息)。

單分支設計的研究

通向標准化設計:Hybrid Dilated Convolution (HDC),可以很好的滿足分割需要,如下圖所示,

 

多分支研究解決多尺度分割

僅僅(在一個卷積分支網絡下)使用 dilated convolution 去抓取多尺度物體是一個不正統的方法。比方說,我們用一個 HDC 的方法來獲取一個大(近)車輛的信息,然而對於一個小(遠)車輛的信息都不再受用。假設我們再去用小 dilated convolution 的方法重新獲取小車輛的信息,則這么做非常的冗余。

基於港中文和商湯組的 PSPNet 里的 Pooling module (其網絡同樣獲得當年的SOTA結果),ASPP 則在網絡 decoder 上對於不同尺度上用不同大小的 dilation rate 來抓去多尺度信息,每個尺度則為一個獨立的分支,在網絡最后把他合並起來再接一個卷積層輸出預測 label。這樣的設計則有效避免了在 encoder 上冗余的信息的獲取,直接關注與物體之間之內的相關性。

 

五、常用框架API介紹

TensorFlow接口

tf.nn.atrous_conv2d(value, filters, rate, padding, name=None)

value: 指需要做卷積的輸入圖像,要求是一個4維Tensor,具有[batch, height, width, channels]這樣的shape,具體含義是[訓練時一個batch的圖片數量, 圖片高度, 圖片寬度, 圖像通道數]

filters: 相當於CNN中的卷積核,要求是一個4維Tensor,具有[filter_height, filter_width, channels, out_channels]這樣的shape,具體含義是[卷積核的高度,卷積核的寬度,圖像通道數,卷積核個數],同理這里第三維channels,就是參數value的第四維

rate: 要求是一個int型的正數,正常的卷積操作應該會有stride(即卷積核的滑動步長),但是空洞卷積是沒有stride參數的,這一點尤其要注意。取而代之,它使用了新的rate參數,那么rate參數有什么用呢?它定義為我們在輸入圖像上卷積時的采樣間隔,你可以理解為卷積核當中穿插了(rate-1)數量的“0”,把原來的卷積核插出了很多“洞洞”,這樣做卷積時就相當於對原圖像的采樣間隔變大了。具體怎么插得,可以看后面更加詳細的描述。此時我們很容易得出rate=1時,就沒有0插入,此時這個函數就變成了普通卷積。

padding: string類型的量,只能是”SAME”,”VALID”其中之一,這個值決定了不同邊緣填充方式。

函數默認stride=1,無法改變。

結果返回一個Tensor,填充方式為“VALID”時,返回[batch,height-2*(filter_width-1),width-2*(filter_height-1),out_channels]的Tensor,填充方式為“SAME”時,返回[batch, height, width, out_channels]的Tensor。

測試代碼如下:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
img = tf.constant(value = [[[[ 1 ],[ 2 ],[ 3 ],[ 4 ]],
                           [[ 1 ],[ 2 ],[ 3 ],[ 4 ]],
                           [[ 1 ],[ 2 ],[ 3 ],[ 4 ]],
                           [[ 1 ],[ 2 ],[ 3 ],[ 4 ]]]],dtype = tf.float32)
img = tf.concat(values = [img,img],axis = 3 )
 
filter = tf.constant(value = 1 , shape = [ 3 , 3 , 2 , 5 ], dtype = tf.float32)
out_img1 = tf.nn.atrous_conv2d(value = img, filters = filter , rate = 1 , padding = 'SAME' )
out_img2 = tf.nn.atrous_conv2d(value = img, filters = filter , rate = 1 , padding = 'VALID' )
out_img3 = tf.nn.atrous_conv2d(value = img, filters = filter , rate = 2 , padding = 'SAME' )
#error
#out_img4 = tf.nn.atrous_conv2d(value=img, filters=filter, rate=2, padding='VALID')
with tf.Session() as sess:
     print ( 'rate=1, SAME mode result:' )
     print (sess.run(out_img1))
     print ( 'rate=1, VALID mode result:' )
     print (sess.run(out_img2))
     print ( 'rate=2, SAME mode result:' )
     print (sess.run(out_img3)) # error #print 'rate=2, VALID mode result:' #print(sess.run(out_img4))

擴張率為1時,空洞卷積等價於普通卷積。對於SAME和VALID模式計算方式如下圖所示,

擴張率為2的VALID模式計算過程,

擴張率為2的VALID模式會報錯,此時卷積核大於圖片,無法卷積。

MXNet接口

MXNet卷積操作自帶擴張率參數,詳見文檔

MXNet的通道存儲與TensorFlow不太一致,所以我們打印一下(對比上面的圖,可以體會到為什么除了tf外大多框架把通道放在第二維),

?
1
2
3
4
5
6
7
8
9
10
11
12
13
import  mxnet as mx
import mxnet.ndarray as nd
 
img = nd.array([[[[ 1 ],[ 2 ],[ 3 ],[ 4 ]],
                 [[ 1 ],[ 2 ],[ 3 ],[ 4 ]],
                 [[ 1 ],[ 2 ],[ 3 ],[ 4 ]],
                 [[ 1 ],[ 2 ],[ 3 ],[ 4 ]]]])
img = nd.concat(img, img, dim = - 1 )
img = nd.transpose(img, axes = ( 0 , 3 , 1 , 2 ))
 
w = nd.ones([ 5 , 2 , 3 , 3 ])
b = nd.array([ 0 for _ in range ( 5 )])
img
[[[[1. 2. 3. 4.]
   [1. 2. 3. 4.]
   [1. 2. 3. 4.]
   [1. 2. 3. 4.]]

[[1. 2. 3. 4.]
[1. 2. 3. 4.]
[1. 2. 3. 4.]
[1. 2. 3. 4.]]]]
<NDArray 1x2x4x4 @cpu(0)>

?
1
nd.Convolution(img, w, b, kernel = w.shape[ 2 :], num_filter = w.shape[ 0 ], stride = ( 1 , 1 ), pad = ( 1 , 1 ), dilate = ( 1 , 1 ))
[[[[12. 24. 36. 28.]
   [18. 36. 54. 42.]
   [18. 36. 54. 42.]
   [12. 24. 36. 28.]]

[[12. 24. 36. 28.]
[18. 36. 54. 42.]
[18. 36. 54. 42.]
[12. 24. 36. 28.]]

[[12. 24. 36. 28.]
[18. 36. 54. 42.]
[18. 36. 54. 42.]
[12. 24. 36. 28.]]

[[12. 24. 36. 28.]
[18. 36. 54. 42.]
[18. 36. 54. 42.]
[12. 24. 36. 28.]]

[[12. 24. 36. 28.]
[18. 36. 54. 42.]
[18. 36. 54. 42.]
[12. 24. 36. 28.]]]]
<NDArray 1x5x4x4 @cpu(0)>

?
1
nd.Convolution(img, w, b, kernel = w.shape[ 2 :], num_filter = w.shape[ 0 ], stride = ( 1 , 1 ), pad = ( 2 , 2 ), dilate = ( 2 , 2 ))
[[[[16. 24. 16. 24.]
   [16. 24. 16. 24.]
   [16. 24. 16. 24.]
   [16. 24. 16. 24.]]

[[16. 24. 16. 24.]
[16. 24. 16. 24.]
[16. 24. 16. 24.]
[16. 24. 16. 24.]]

[[16. 24. 16. 24.]
[16. 24. 16. 24.]
[16. 24. 16. 24.]
[16. 24. 16. 24.]]

[[16. 24. 16. 24.]
[16. 24. 16. 24.]
[16. 24. 16. 24.]
[16. 24. 16. 24.]]

[[16. 24. 16. 24.]
[16. 24. 16. 24.]
[16. 24. 16. 24.]
[16. 24. 16. 24.]]]]
<NDArray 1x5x4x4 @cpu(0)>

六、參考來源

Multi-scale Context Aggregation by Dilated Convolutions

【Tensorflow】tf.nn.atrous_conv2d如何實現空洞卷積?

如何理解空洞卷積(dilated convolution)?

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM