Making Convolutional Networks Shift-Invariant Again
Intro
本文提出解決CNN平移不變性喪失的方法,之前說了CNN中的downsample過程由於不滿足采樣定理,所以沒法確保平移不變性。信號處理里面解決這樣的問題是利用增大采樣頻率或者用抗混疊方法,前者在圖像處理里面設置stride 1就可實現,但stride 1已經是極限,本文着重於后者,使用抗混疊使得CNN重新具有平移不變性。
混疊是在采樣頻率不滿足采樣定理時出現的一種現象,抗混疊通過抗混疊濾波器消除混疊,即先用低通濾波器處理,然后再去采樣,這樣可以消除高頻信號造成的不滿足采樣定理的情況。那么在圖像處理里,理論上avg pooling就可以減小高頻影響,但是有相關研究表明max-pooling在結果上要優於avg pooling(不考慮平移不變性只考慮分類等結果),但是max-pooling是不滿足采樣定理的,這就很尷尬。
Methods
先解釋兩個概念
平移不變性:指的是輸入平移一定距離,最終的結果不變,分類里面就是分類的概率結果是不變的。
平移同變形:指的是輸入平移一定距離,其對應的feature也做同樣的平移。
本文主要是針對特征的平移同變性去解決問題,而實際上實現了特征的平移同變形,后面接的是fc層,最后一層的平移不變性是等價於平移同變性的,所以實現了特征的平移同變性就是實現了整個網絡輸出的平移不變性。例如,vgg網絡的最后兩層是fc層和softmax,顯然fc層的spatial dim只有唯一一個元素(高維向量),所以平移不變性和平移等變性在這一層是等價的。
作者認為,max-pooling的過程可以分為兩個過程,max操作和采樣操作,其中max操作是平移等變的,因為max操作是利用滑動窗口實現的(stride 1),然后進行最大采樣,就實現了下采樣,而在采樣過程中保留了高頻部分(采樣頻率又相對較低),所以會導致不滿足采樣定理。為了使max-pooling滿足采樣定理,對采樣之前的信號用低通濾波處理,即可實現混疊。低通濾波加上采樣操作,作者成為BlurPool。
對於stride 2 conv,同樣地,卷積過程其實也分為兩個過程,對采樣前的信號低通濾波處理即可:conv(stride k)-relu替換為conv(stride 1)-relu-BlurPool(k)。avg pooling同理。
操作如圖所示。
Coding
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Downsample(nn.Module):
def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
super(Downsample, self).__init__()
self.filt_size = filt_size
self.pad_off = pad_off
self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
self.stride = stride
self.off = int((self.stride-1)/2.)
self.channels = channels
if(self.filt_size==1):
a = np.array([1.,])
elif(self.filt_size==2):
a = np.array([1., 1.])
elif(self.filt_size==3):
a = np.array([1., 2., 1.])
elif(self.filt_size==4):
a = np.array([1., 3., 3., 1.])
elif(self.filt_size==5):
a = np.array([1., 4., 6., 4., 1.])
elif(self.filt_size==6):
a = np.array([1., 5., 10., 10., 5., 1.])
elif(self.filt_size==7):
a = np.array([1., 6., 15., 20., 15., 6., 1.])
filt = torch.Tensor(a[:,None]*a[None,:])
filt = filt/torch.sum(filt)
self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))
self.pad = get_pad_layer(pad_type)(self.pad_sizes)
def forward(self, inp):
if(self.filt_size==1):
if(self.pad_off==0):
return inp[:,:,::self.stride,::self.stride]
else:
return self.pad(inp)[:,:,::self.stride,::self.stride]
else:
return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
class AntialiasNet(nn.Module):
def __init__(self,num_classes = 10):
super(AntialiasNet,self).__init__()
self.net = nn.Sequential(
nn.Conv2d(1,32,3,stride = 1),
nn.BatchNorm2d(32),
nn.ReLU(inplace = True),
nn.Conv2d(32,64,3,stride = 1),
nn.BatchNorm2d(64),
nn.ReLU(inplace = True),
Downsample(channels = 64),
nn.Conv2d(64,128,3,stride = 1),
nn.BatchNorm2d(128),
nn.ReLU(inplace = True),
nn.Conv2d(128,256,3,stride = 1),
nn.BatchNorm2d(256),
nn.ReLU(inplace = True),
Downsample(channels = 256),
)
self.avg_pool = nn.AdaptiveAvgPool2d((7,7))
self.classifier = nn.Sequential(
nn.Linear(256 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
def forward(self,x):
return self.classifier(self.avg_pool(self.net(x)).view(-1,256*7*7))
def get_pad_layer(pad_type):
if(pad_type in ['refl','reflect']):
PadLayer = nn.ReflectionPad2d
elif(pad_type in ['repl','replicate']):
PadLayer = nn.ReplicationPad2d
elif(pad_type=='zero'):
PadLayer = nn.ZeroPad2d
else:
print('Pad type [%s] not recognized'%pad_type)
return PadLayer
if __name__ == "__main__":
x = torch.randn(3,1,28,28)
net = AntialiasNet()
print(net(x))