Spatial Transformer Networks
簡介
本文提出了能夠學習feature仿射變換的一種結構,並且該結構不需要給其他額外的監督信息,網絡自己就能學習到對預測結果有用的仿射變換。因為CNN的平移不變性等空間特征一定程度上被pooling等操作破壞了,所以,想要網絡能夠應對平移的object或者其他仿射變換后的object有更好的表示,就需要設計一種結構來學習這種變換,使得作用了這種變換后的feature能夠能好的表示任務。
網絡結構
上圖中U表示輸入feature map,通過spatial transformer 分支學習到transform,然后通過差值或其他sampler映射到輸出feature,這樣輸出的feature會有一種更加健壯的表示。
spatial transform的結構由三個部分組成,下面會詳細介紹。
仿射變換
仿射變換分為平移、縮放、翻轉、旋轉和裁剪這幾種變換,其中二維的變換可以用矩陣來表示:
其中theta對應取不同的值會對應不同的變換。所以網絡同學學習到這種變換,幫助feature得到一種更加有效的表示。
Localisation Network
該部分對應與上圖中的localisation net部分,目的是為了學習到上面公式中的theta參數,也就是說,這一部分的結構可以直接全連接6個theta或者使用conv結構,只要能映射到6個theta就可以了。這一部分比較簡單。
Parameterised Sampling Grid
這一部分對應於上圖的Grid Generator部分,這一部分的作用是建立輸入圖像位置到輸出圖像位置的映射,也就是對應於我們上面提到的仿射變換,我們在這一結構下可以通過上面學習到的參數theta來通過矩陣形式對輸入進行放放射變換,注意變換的時候每個channel的變換應該是一致的。公式表示為:
我們可以通過限定theta的取值來限定網絡只學習某種變換,也就是只學習一部分theta參數。
Differentiable Image Sampling
上面放射變換只是定義了變換前到變換后的位置映射,這個映射其實並不完整,這就意味着有些點是沒有值的,如果要給值,就要使用插值的方法了。論文中提到了最鄰近插值和雙線性插值兩種插值方法。
對於最鄰近插值給出了這樣的定義:
這樣對於輸出feature的第i個值,其對應的輸入feature的位置取決於m和n,由krnoecker delta函數定義知,當且僅當自變量為0時輸出為1.所以上式只有在m取得x方向上距離對應點最近的整數點以及n取得y方向上距離最近的整數點時有值,其值就為對應兩個方向都最近的點的值。
對於雙線性插值給出了這樣的定義:
由上式可以知道,只有當m和n取值為對應點xy方向上距離為1以內的整數時才有值,而距離對應點最近的整數點是有四個的,比如(0.5,0.5)距離其最近的四個點分別為(0,0),(0,1),(1,1),(1,0),后面兩個取值就成了距離權重,前面U取值為四個點之一的整數點的值,所以這個式子可以解釋為以距離作為權重,取最近的四個點的值的加權求和。
反向傳播
定義了上面的對應函數,作者證明了輸出到輸入是可以進行反向傳播的,以雙線性插值為例:
import torch
import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F
from torchsummary import summary
class STN(nn.Module):
def __init__(self):
super(STN,self).__init__()
self.feature_extractor = vgg16(pretrained = False).features
self.conv = nn.Conv2d(512,256,7)
self.fc = nn.Sequential(
nn.Linear(256,512),
nn.ReLU(),
nn.Linear(512,6)
)
def forward(self,x):
features = self.feature_extractor(x) # (b,c,h,w) h = w = 7 c = 512
theta = self.conv(features).view(-1,256) # b,256
theta = self.fc(theta).view(-1,2,3) # b,2
transformed = F.affine_grid(theta,x.size()) # theta (n,2,3) size (n,c,h,w) ,這一步是得到仿射變換的映射
x = F.grid_sample(x,transformed) # 這一步就是根據映射關系,去做插值,得到變換后的圖像
return x
if __name__ == "__main__":
net = STN()
summary(net,(3,224,224),device = "cpu")
[Running] python -u "/media/xueaoru/DATA/ubuntu/six/STN.py"
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 224, 224] 1,792
ReLU-2 [-1, 64, 224, 224] 0
Conv2d-3 [-1, 64, 224, 224] 36,928
ReLU-4 [-1, 64, 224, 224] 0
MaxPool2d-5 [-1, 64, 112, 112] 0
Conv2d-6 [-1, 128, 112, 112] 73,856
ReLU-7 [-1, 128, 112, 112] 0
Conv2d-8 [-1, 128, 112, 112] 147,584
ReLU-9 [-1, 128, 112, 112] 0
MaxPool2d-10 [-1, 128, 56, 56] 0
Conv2d-11 [-1, 256, 56, 56] 295,168
ReLU-12 [-1, 256, 56, 56] 0
Conv2d-13 [-1, 256, 56, 56] 590,080
ReLU-14 [-1, 256, 56, 56] 0
Conv2d-15 [-1, 256, 56, 56] 590,080
ReLU-16 [-1, 256, 56, 56] 0
MaxPool2d-17 [-1, 256, 28, 28] 0
Conv2d-18 [-1, 512, 28, 28] 1,180,160
ReLU-19 [-1, 512, 28, 28] 0
Conv2d-20 [-1, 512, 28, 28] 2,359,808
ReLU-21 [-1, 512, 28, 28] 0
Conv2d-22 [-1, 512, 28, 28] 2,359,808
ReLU-23 [-1, 512, 28, 28] 0
MaxPool2d-24 [-1, 512, 14, 14] 0
Conv2d-25 [-1, 512, 14, 14] 2,359,808
ReLU-26 [-1, 512, 14, 14] 0
Conv2d-27 [-1, 512, 14, 14] 2,359,808
ReLU-28 [-1, 512, 14, 14] 0
Conv2d-29 [-1, 512, 14, 14] 2,359,808
ReLU-30 [-1, 512, 14, 14] 0
MaxPool2d-31 [-1, 512, 7, 7] 0
Conv2d-32 [-1, 256, 1, 1] 6,422,784
Linear-33 [-1, 512] 131,584
ReLU-34 [-1, 512] 0
Linear-35 [-1, 6] 3,078
================================================================
Total params: 21,272,134
Trainable params: 21,272,134
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.40
Params size (MB): 81.15
Estimated Total Size (MB): 300.13
----------------------------------------------------------------
[Done] exited with code=0 in 2.511 seconds