文章轉自微信公眾號:【機器學習煉丹術】
參考目錄:
- 論文名稱:“Non-local Neural Networks”
- 論文地址:https://arxiv.org/abs/1711.07971
0 概述
首先,這個論文中的模塊,叫做non-local block,然后這個思想是基於NLP中的self-attention自注意力機制的。所以在提到CV中的self-attention,最先想到的就是non-local這個論文。這個論文提出的動機如下:
卷積運算和遞歸操作都在空間或時間上處理一個local鄰域;只有在重復應用這些運算、通過數據逐步傳播信號時,才能捕獲long-range相關性。
換句話說,在卷積網絡中,想要增加視野域,就要不斷的增加卷積層數量和池化層數量,換句話說,增加視野域就是增加網絡的深度。這樣必然會增加計算的成本,參數的數量,還需要考慮梯度消失問題。
- long-time相關性在NLP中,就是指一句話中,兩個距離很遠的單詞的相關性,在CV中則是指一個圖片中距離很遠的兩個部分的相關性。一般CNN識別物體,都是只關注物體周圍的像素,而不會考慮很遠的地方,這也就是CNN的一個特性,局部視野域。在某些情況下,這是天然優勢,當然也可能變成劣勢。
1 主要內容
本次我們學習先看代碼,然后再從論文中解析代碼。
1.1 Non local的優勢
- 通過少的參數,少的層數,捕獲遠距離的依賴關系;
- 即插即用
1.2 pytorch復現
class Self_Attn(nn.Module):
""" Self attention Layer"""
def __init__(self,in_dim,activation):
super(Self_Attn,self).__init__()
self.chanel_in = in_dim
self.activation = activation
self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self,x):
"""
inputs :
x : input feature maps( B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
m_batchsize,C,width ,height = x.size()
proj_query = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
proj_key = self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
energy = torch.bmm(proj_query,proj_key) # transpose check
attention = self.softmax(energy) # BX (N) X (N)
proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N
out = torch.bmm(proj_value,attention.permute(0,2,1) )
out = out.view(m_batchsize,C,width,height)
out = self.gamma*out + x
return out,attention
1.3 代碼解讀
輸入特征圖為BatchxChannelxHeightxWidth,我們先把這個輸入特征圖x分別放入:
- query卷積層,得到BatchxChannel//8xHeightxWidth
- key卷積層,得到BatchxChannel//8xHeightxWidth
- value卷積層,得到BatchxChannelxHeightxWidth
我們要逐個像素的計算query和key的相似度,然后相似度高的像素更為重要,相似度低的像素就不那么重要,每個像素我們用channel//8這個長度的向量來表示。(這里可能比較抽象,畢竟self-attention的原版是NLP領域的,non-local是從NLP中照搬過來的,所以不太好直接理解)
相似度計算是通過向量的乘法來表示的,那么我們肯定不能把這個HeightxWidth這么多像素一個一個計算像素的相似度。所以我們把BatchxChannel//8xHeightxWidth轉換成BatchxChannel//8xN的形式,這里的N是HeightxWidth,N表示圖中像素的數量。
然后我們用torch.bmm()
來做矩陣的乘法:(N,Channel//8)和(Channel//8,N)兩個矩陣相乘,得到一個(N,N)的矩陣。
這個(N,N)矩陣中的第i行第j列元素的值,是圖中i位置像素和j位置像素的相關性!然后我們把value矩陣和這個(N,N)再進行一次矩陣乘法,這樣得到的輸出,就是考慮了全局信息的特征圖了。
第二次矩陣乘法中,是(Channel,N)和(N,N)的相乘,得到的輸出的特征圖中的每一個值,都是N個值的加權平均,這也說明了輸出的特征圖中的每一個值,都是考慮了整張圖的像素的。
1.4 論文解讀
上圖是論文中對於non-local的結構圖。可以看到,先通過1x1的卷積,降低一下通道數,然后通過\(\theta和\phi\)分別是query和key,然后這兩個卷積得到(N,N)的矩陣,然后再與\(g\)(value)進行矩陣乘法。
好吧我承認和代碼在通道數上略微有些出入,但是大體思想相同。
2 總結
- 經過了non-local的特征圖,視野域擴大到了全圖,而且並沒有增加很多的參數。
- 但是因為經過了BMM矩陣呢的乘法,梯度計算圖急速擴大,因此計算和內存會消耗很大。因此,我在網絡的深層(特征圖尺寸較小的時候),才會加上一層non-local層。但是!!!論文中說,盡量放在靠前的層,所以在計算力允許的情況下,往前放。
- 這個方法在一部分的任務中,確實有提升,我自己試過,還可以。
- 后續又有很多來降低這個計算消耗的算法,之后我們在講,喜歡的點個關注和贊吧~