https://zhuanlan.zhihu.com/p/28186857
這個例子說明了什么叫做空間可分離卷積,這種方法並不應用在深度學習中,只是用來幫你理解這種結構。
在神經網絡中,我們通常會使用深度可分離卷積結構(depthwise separable convolution)。
這種方法在保持通道分離的前提下,接上一個深度卷積結構,即可實現空間卷積。接下來通過一個例子讓大家更好地理解。
假設有一個3×3大小的卷積層,其輸入通道為16、輸出通道為32。具體為,32個3×3大小的卷積核會遍歷16個通道中的每個數據,從而產生16×32=512個特征圖譜。進而通過疊加每個輸入通道對應的特征圖譜后融合得到1個特征圖譜。最后可得到所需的32個輸出通道。
針對這個例子應用深度可分離卷積,用16個3×3大小的卷積核分別遍歷16通道的數據,得到了16個特征圖譜。在融合操作之前,接着用32個1×1大小的卷積核遍歷這16個特征圖譜,進行相加融合。這個過程使用了16×3×3+16×32×1×1=656個參數,遠少於上面的16×32×3×3=4608個參數。
這個例子就是深度可分離卷積的具體操作,其中上面的深度乘數(depth multiplier)設為1,這也是目前這類網絡層的通用參數。
這么做是為了對空間信息和深度信息進行去耦。從Xception模型的效果可以看出,這種方法是比較有效的。由於能夠有效利用參數,因此深度可分離卷積也可以用於移動設備中。
src convolution
input output
M*N*Cin M*N*Cout
16*3*3*32
depthwise separable convolution
input output1 output2
M*N*Cin M*N*Cin M*N*Cout
16*3*3 16*32*1*1
另外一個地方看到的解釋:
MobileNet-v1:
MobileNet主要用於移動端計算模型,是將傳統的卷積操作改為兩層的卷積操作,在保證准確率的條件下,計算時間減少為原來的1/9,計算參數減少為原來的1/7.
MobileNet模型的核心就是將原本標准的卷積操作因式分解成一個depthwise convolution和一個1*1的pointwise convolution操作。簡單講就是將原來一個卷積層分成兩個卷積層,其中前面一個卷積層的每個filter都只跟input的每個channel進行卷積,然后后面一個卷積層則負責combining,即將上一層卷積的結果進行合並。
depthwise convolution:
比如輸入的圖片是Dk*Dk*M(Dk是圖片大小,M是輸入的渠道數),那么有M個Dw*Dw的卷積核,分別去跟M個渠道進行卷積,輸出Df*Df*M結果
pointwise convolution:
對Df*Df*M進行卷積合並,有1*1*N的卷積,進行合並常規的卷積,輸出Df*Df*N的結果
上面經過這兩個卷積操作,從一個Dk*Dk*M=>Df*Df*N,相當於用Dw*Dw*N的卷積核進行常規卷積的結果,但計算量從原來的DF*DF*DK*DK*M*N減少為DF*DF*DK*DK*M+DF*DF*M*N.
第一層為常規卷積,后面接着都為depthwise convolution+pointwise convolution,最后兩層為Pool層和全連接層,總共28層.
下面的代碼是mobilenet的一個參數列表,計算的普通卷積與深度分離卷積的計算復雜程度比較
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py
# Tensorflow mandates these. from collections import namedtuple import functools import tensorflow as tf slim = tf.contrib.slim # Conv and DepthSepConv namedtuple define layers of the MobileNet architecture # Conv defines 3x3 convolution layers # DepthSepConv defines 3x3 depthwise convolution followed by 1x1 convolution. # stride is the stride of the convolution # depth is the number of channels or filters in a layer Conv = namedtuple('Conv', ['kernel', 'stride', 'depth']) DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth']) # _CONV_DEFS specifies the MobileNet body _CONV_DEFS = [ Conv(kernel=[3, 3], stride=2, depth=32), DepthSepConv(kernel=[3, 3], stride=1, depth=64), DepthSepConv(kernel=[3, 3], stride=2, depth=128), DepthSepConv(kernel=[3, 3], stride=1, depth=128), DepthSepConv(kernel=[3, 3], stride=2, depth=256), DepthSepConv(kernel=[3, 3], stride=1, depth=256), DepthSepConv(kernel=[3, 3], stride=2, depth=512), DepthSepConv(kernel=[3, 3], stride=1, depth=512), DepthSepConv(kernel=[3, 3], stride=1, depth=512), DepthSepConv(kernel=[3, 3], stride=1, depth=512), DepthSepConv(kernel=[3, 3], stride=1, depth=512), DepthSepConv(kernel=[3, 3], stride=1, depth=512), DepthSepConv(kernel=[3, 3], stride=2, depth=1024), DepthSepConv(kernel=[3, 3], stride=1, depth=1024) ] input_size = 160 inputdepth = 3 conv_defs = _CONV_DEFS sumcost = 0 for i, conv_def in enumerate(conv_defs): stride = conv_def.stride kernel = conv_def.kernel outdepth = conv_def.depth output_size = round((input_size - int(kernel[0] / 2) * 2) / stride) if isinstance(conv_def, Conv): sumcost += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth if isinstance(conv_def, DepthSepConv): sumcost += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth inputdepth = outdepth input_size = output_size print("src conv: ", sumcost) input_size = 160 inputdepth = 3 conv_defs = _CONV_DEFS sumcost1 = 0 for i, conv_def in enumerate(conv_defs): stride = conv_def.stride kernel = conv_def.kernel outdepth = conv_def.depth output_size = round((input_size - int(kernel[0] / 2) * 2) / stride) if isinstance(conv_def, Conv): sumcost1 += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth if isinstance(conv_def, DepthSepConv): #sumcost += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth sumcost1 += output_size * output_size *(inputdepth * kernel[0] * kernel[0] + inputdepth * outdepth * 1 * 1) inputdepth = outdepth input_size = output_size print("DepthSepConv:", sumcost1) print("compare:", sumcost1 / sumcost)
src conv: 1045417824
DepthSepConv: 126373376
compare: 0.12088312739538674
mobilenet V1介紹