Convolution with even-sized kernels and symmetric padding
Intro
本文探究了偶數kernel size的卷積對網絡的影響,結果表明偶數卷積在結果上並不如奇數卷積。文章從實驗與原理上得出結論,偶數卷積之所以結果更差,是因為偶數卷積會使得feature map偏移,即“the shift problem”,這將導致學習到的feature不具備更強的表征能力。本文提出信息侵蝕假設,認為奇數卷積中心對稱,而偶數卷積在實現時沒有對稱點,這將導致在實現時卷積利用的信息不能是各個方向的,只能是左上或其他方向(取決於具體實現),因此整體將會導致feature map往一個方向偏移。為了解決這個問題,文章提出symmetric padding方法來彌補各個方向帶來的損失,結果提示很明顯。
The shift problem
奇數kernel size的卷積實現起來很容易,在對應位置計算當前位置和其八個鄰域方向位置的feature值與權重求和即對應輸出位置一個值,而偶數kernel size的卷積要如何實現呢?在tensorflow里,偶數kernel size的卷積如2×2卷積是利用當前點和其左上、上方、左方一共四個點與對應權值相乘求和得到的,正因如此,在實現過程,輸出feature map的感受野其實是有缺陷的,他只能對應與其左上方的區域,多層卷積之后暴露出來的問題就是feature map的偏移。
如上圖所示,第一行是沒使用本文方法padding的conv2x2的結果,第二行是本文方法的結果。
輸出的feature map和輸入feature map的關系可以表述為:
其中p表示位置坐標,\(F^i\)和\(F^o\)分別表示輸入feature和輸出feature,\(\delta\)表示卷積核內的位置,\(\omega_i\)表示卷積核的權重。
對於奇數kernel size卷積,其中\(\mathcal{R}\)可以寫成:
以kernel的中心為原點,各個方向的偏移值就是上面的表示。對於奇數卷積,顯然上式是中心對稱的,而對於偶數kernel size卷積,對應的\(\mathcal{R}\)定義為:
上式並沒有利用各個方向的信息,並且卷積核並不是中心對稱的。
The information erosion hypothesis
對於輸入位置p,經過n次偶數卷積之后對應的位置為:
因而網絡越深,shift現象越嚴重。
為了說明偶數卷積對信息的侵蝕作用,文章定義feature的L1范數為該feature map信息量的度量。
基於這個定義,文章實驗了不同kernel size的信息量,得到如圖所示結果:
可以看到C3和C5整體信息量銳減的比C2和C4慢的多。
Symmetric padding
為了解決shift帶來的影響,本文提出了一種padding方式,具體的操作如圖所示:
即先將feature map分成四個group,每個group在不同方向上按如圖所示的方式進行padding,最后不用padding直接conv2x2即可。
這樣做的好處就是使得網絡的某些channel能利用到特定方向的信息,從宏觀上看網絡利用到了各個方向的信息,一定程度上緩解了shift帶來的問題。
Codding
import torch
import torch.nn as nn
import torch.nn.functional as F
class SpConv2d(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size,stride,padding,*args,**kwargs):
super(SpConv2d,self).__init__()
self.conv = nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding)
def forward(self,x):
n,c,h,w = x.size()
assert c % 4 == 0
x1 = x[:,:c//4,:,:]
x2 = x[:,c//4:c//2,:,:]
x3 = x[:,c//2:c//4*3,:,:]
x4 = x[:,c//4*3:c,:,:]
x1 = nn.functional.pad(x1,(1,0,1,0),mode = "constant",value = 0) # left top
x2 = nn.functional.pad(x2,(0,1,1,0),mode = "constant",value = 0) # right top
x3 = nn.functional.pad(x3,(1,0,0,1),mode = "constant",value = 0) # left bottom
x4 = nn.functional.pad(x4,(0,1,0,1),mode = "constant",value = 0) # right bottom
x = torch.cat([x1,x2,x3,x4],dim = 1)
return self.conv(x)
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv = SpConv2d(4,16,2,1,0)
def forward(self,x):
return self.conv(x)
if __name__ == "__main__":
x = torch.randn(2,4,14,14)
net = Net()
print(net(x))