[論文理解] Spatial Transformer Networks


Spatial Transformer Networks

簡介

本文提出了能夠學習feature仿射變換的一種結構,並且該結構不需要給其他額外的監督信息,網絡自己就能學習到對預測結果有用的仿射變換。因為CNN的平移不變性等空間特征一定程度上被pooling等操作破壞了,所以,想要網絡能夠應對平移的object或者其他仿射變換后的object有更好的表示,就需要設計一種結構來學習這種變換,使得作用了這種變換后的feature能夠能好的表示任務。

網絡結構

上圖中U表示輸入feature map,通過spatial transformer 分支學習到transform,然后通過差值或其他sampler映射到輸出feature,這樣輸出的feature會有一種更加健壯的表示。

spatial transform的結構由三個部分組成,下面會詳細介紹。

仿射變換

仿射變換分為平移、縮放、翻轉、旋轉和裁剪這幾種變換,其中二維的變換可以用矩陣來表示:

\[\left(\begin{matrix} x' \\ y' \end{matrix}\right) = \left[\begin{matrix} \theta_1 & \theta_2 & \theta_3 \\ \theta_4 & \theta_5 & \theta_6 \\ \end{matrix}\right] \left(\begin{matrix} x \\ y\\ 1 \end{matrix}\right) \]

其中theta對應取不同的值會對應不同的變換。所以網絡同學學習到這種變換,幫助feature得到一種更加有效的表示。

Localisation Network

該部分對應與上圖中的localisation net部分,目的是為了學習到上面公式中的theta參數,也就是說,這一部分的結構可以直接全連接6個theta或者使用conv結構,只要能映射到6個theta就可以了。這一部分比較簡單。

Parameterised Sampling Grid

這一部分對應於上圖的Grid Generator部分,這一部分的作用是建立輸入圖像位置到輸出圖像位置的映射,也就是對應於我們上面提到的仿射變換,我們在這一結構下可以通過上面學習到的參數theta來通過矩陣形式對輸入進行放放射變換,注意變換的時候每個channel的變換應該是一致的。公式表示為:

我們可以通過限定theta的取值來限定網絡只學習某種變換,也就是只學習一部分theta參數。

Differentiable Image Sampling

上面放射變換只是定義了變換前到變換后的位置映射,這個映射其實並不完整,這就意味着有些點是沒有值的,如果要給值,就要使用插值的方法了。論文中提到了最鄰近插值和雙線性插值兩種插值方法。

對於最鄰近插值給出了這樣的定義:

這樣對於輸出feature的第i個值,其對應的輸入feature的位置取決於m和n,由krnoecker delta函數定義知,當且僅當自變量為0時輸出為1.所以上式只有在m取得x方向上距離對應點最近的整數點以及n取得y方向上距離最近的整數點時有值,其值就為對應兩個方向都最近的點的值。

對於雙線性插值給出了這樣的定義:

由上式可以知道,只有當m和n取值為對應點xy方向上距離為1以內的整數時才有值,而距離對應點最近的整數點是有四個的,比如(0.5,0.5)距離其最近的四個點分別為(0,0),(0,1),(1,1),(1,0),后面兩個取值就成了距離權重,前面U取值為四個點之一的整數點的值,所以這個式子可以解釋為以距離作為權重,取最近的四個點的值的加權求和。

反向傳播

定義了上面的對應函數,作者證明了輸出到輸入是可以進行反向傳播的,以雙線性插值為例:

import torch

import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F

from torchsummary import summary
class STN(nn.Module):
    def __init__(self):
        super(STN,self).__init__()
        self.feature_extractor = vgg16(pretrained = False).features
        self.conv = nn.Conv2d(512,256,7)
        self.fc = nn.Sequential(
            nn.Linear(256,512),
            nn.ReLU(),
            nn.Linear(512,6)
        )
    
    def forward(self,x):
        features = self.feature_extractor(x) # (b,c,h,w) h = w = 7 c = 512
        theta = self.conv(features).view(-1,256) # b,256
        theta = self.fc(theta).view(-1,2,3) # b,2
        transformed = F.affine_grid(theta,x.size()) # theta (n,2,3) size (n,c,h,w) ,這一步是得到仿射變換的映射
        x = F.grid_sample(x,transformed) # 這一步就是根據映射關系,去做插值,得到變換后的圖像
        return x

if __name__ == "__main__":
    net = STN()
    summary(net,(3,224,224),device = "cpu")



[Running] python -u "/media/xueaoru/DATA/ubuntu/six/STN.py"
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Conv2d-32            [-1, 256, 1, 1]       6,422,784
           Linear-33                  [-1, 512]         131,584
             ReLU-34                  [-1, 512]               0
           Linear-35                    [-1, 6]           3,078
================================================================
Total params: 21,272,134
Trainable params: 21,272,134
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.40
Params size (MB): 81.15
Estimated Total Size (MB): 300.13
----------------------------------------------------------------

[Done] exited with code=0 in 2.511 seconds


論文原文:https://arxiv.org/pdf/1506.02025.pdf


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM