論文: Xception: Deep Learning with Depthwise Separable Convolutions
論文地址: https://arxiv.org/abs/1610.02357
代碼地址:
參考博客:
1. 提出背景
Xception 是 google 繼 Inception 后提出的對 Inception v3 的另一種改進,主要采用 depthwise separable convolution來替代原來的 Inception v3 中的卷積操作.
1. Inception模塊簡介
Inception v3 的結構圖如下Figure1:
當時提出Inception 的初衷可以認為是:
特征提取和傳遞可以通過 1x1,3x3,5x5 conv以及pooling,究竟哪種提取特征方式好呢,Inception 結構將這個疑問留給網絡自己訓練,也就是將一個輸入同時輸出給這幾種特征提取方式,然后做Concatnate.
- Inception v3 和 Inception v1 主要的區別是將 5x5的卷積核換成了2個 3x3 卷積核的疊加.
2. Inception模型簡化
正如之前論文RexNeXt所說,Inception網絡太依賴於人工設計了。於是結合ResNeXt的思想,從Inception V3聯想到簡化的Inception結構,就是Figure 2.
3. Inception模型拓展
我們可以做個等效的變換,事實上效果是一樣的,有了Figure 3:
Figure 3 表示對於一個輸入,先用一個統一的 1x1 的卷積核卷積,然后再連接3個 3x3卷積,這三個操作只將前面 1x1 卷積結果的一部分作為自己的輸入(這里是1/3channel)的卷積.
既然如此,不如干脆點:
3x3的卷積核的個數和 1x1的輸出channel 一樣多,每個 3x3卷積都只和1個輸入的channel做卷積.
2.論文核心
2.1 Depthwise Separable Convolution 深度分離卷積
DepthWise卷積PointWise卷積,合起來稱作DepthWise Separable Convolution,該結構和常規卷積操作類似,可用來提取特征。但是相比較常規卷積操作,其參數量和運算成本較低,所以在一些輕量級網絡中會碰到這種結構,比如說MobileNet.
2.1.1 常規卷積操作
對於一張 5x5 像素,三通道彩色輸入圖片(5x5x3),經過3x3卷積核的層,假設輸出通道數量為4,則卷積核的shape為 3x3x3x4,最終輸出4個Feature Map.
- 如果為
padding=same
,則特征圖尺寸為5x5 padding=valid
,特征圖尺寸3x3
2.1.2 DepthWise Convolution
不同於常規卷積操作:
DepthWise Convolution的一個卷積核只負責一個通道,一個通道只能被一個卷積核卷積.
同樣對於這張5x5像素,三通道的彩色輸入圖片,DepthWise Convolution首先經過第一次卷積運算,不同於上面的常規卷積
DepthWise Convolution完全在二維平面進行,卷積核數量與上一層必須一致.(上一層通道與卷積核個數一致)
所以一個三通道的圖像經過運算后生成了3個Feature Map.
但是這就存在一個缺點,首先:
- DepthWise Convolution 完成后 Feature Map 數量和輸入層的通道數量相同,無法拓展Feature Map.
- 這種運算對輸入層的每個通道獨立進行卷積運算,沒有有效利用在相同空間位置上的 feature 信息.
因此采用 PointWise Convolution 來將這些Feature Map重新組合生成新的 Feature Map.
2.1.3 PointWise Convolution
PointWise Convolution 的運算與常規卷積運算非常相似,它的卷積核的尺寸為 1x1xM,其中M為上一層的通道數量:
這里PointWise Convolution 運算會將上一步的map在深度方向上進行加權組合,生成新的Feature map,有多少個卷積核就有多少個輸出Feature Maps.
有意思的是其實之前許多網絡,例如Inception v3用PointWise Convolution來做維度縮減來降低參數,這里用來聯系以及拓展Feature Maps。而DepthWise Convolution也並不是新出現的,它可以看做是分組卷積的特例,早在AlexNet就出現過.
3. 網絡結構
Xception作為Inception v3的改進,主要是在Inception v3的基礎上引入了depthwise separable convolution,在基本不增加網絡復雜度的前提下提高了模型的效果.
疑問
- 有些人會好奇為什么引入depthwise separable convolution沒有大大降低網絡的復雜度?
原因在於作者加寬了網絡,使得參數數量和Inception v3差不多,然后在這前提下比較性能.因此Xception目的不在於模型壓縮,而是提高性能.
4. 核心代碼
from keras.models import Model
from keras.layers import Dense, Input, BatchNormalization, Activation, add
from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalAveragePooling2D
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.utils import plot_model
def Xception():
input_shape = _obtain_input_shape(None, default_size=299, min_size=71, data_format='channels_last', require_flatten=True)
img_input = Input(shape=input_shape)
# Block 1
x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False)(img_input)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(64, (3, 3), use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 2
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = add([x, residual])
residual = Conv2D(256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 3
x = Activation('relu')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = add([x, residual])
residual = Conv2D(728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 4
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = add([x, residual])
# Block 5-12
for i in range(8):
residual = x
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = add([x, residual])
residual = Conv2D(1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 13
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
# Block 13 Pool
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = add([x, residual])
# Block 14
x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
# Block 14 part2
x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
# 全鏈接層
x = GlobalAveragePooling2D()(x)
x = Dense(1000, activation='softmax')(x)
return Model(inputs=img_input, outputs=x, name='xception')
if __name__ == '__main__':
model = Xception()
model.summary()
plot_model(model, show_shapes=True)