Xception


論文: Xception: Deep Learning with Depthwise Separable Convolutions

論文地址: https://arxiv.org/abs/1610.02357

代碼地址:

  1. Keras: https://github.com/yanchummar/xception-keras

參考博客:

  1. Xception算法詳解
  2. Depthwise卷積與Pointwise卷積

1. 提出背景

Xception 是 google 繼 Inception 后提出的對 Inception v3 的另一種改進,主要采用 depthwise separable convolution來替代原來的 Inception v3 中的卷積操作.

1. Inception模塊簡介

Inception v3 的結構圖如下Figure1:

當時提出Inception 的初衷可以認為是:

特征提取和傳遞可以通過 1x1,3x3,5x5 conv以及pooling,究竟哪種提取特征方式好呢,Inception 結構將這個疑問留給網絡自己訓練,也就是將一個輸入同時輸出給這幾種特征提取方式,然后做Concatnate.

  • Inception v3 和 Inception v1 主要的區別是將 5x5的卷積核換成了2個 3x3 卷積核的疊加.

2. Inception模型簡化

正如之前論文RexNeXt所說,Inception網絡太依賴於人工設計了。於是結合ResNeXt的思想,從Inception V3聯想到簡化的Inception結構,就是Figure 2.

3. Inception模型拓展

我們可以做個等效的變換,事實上效果是一樣的,有了Figure 3:

Figure 3 表示對於一個輸入,先用一個統一的 1x1 的卷積核卷積,然后再連接3個 3x3卷積,這三個操作只將前面 1x1 卷積結果的一部分作為自己的輸入(這里是1/3channel)的卷積.

既然如此,不如干脆點:

3x3的卷積核的個數和 1x1的輸出channel 一樣多,每個 3x3卷積都只和1個輸入的channel做卷積.

2.論文核心

2.1 Depthwise Separable Convolution 深度分離卷積

DepthWise卷積PointWise卷積,合起來稱作DepthWise Separable Convolution,該結構和常規卷積操作類似,可用來提取特征。但是相比較常規卷積操作,其參數量和運算成本較低,所以在一些輕量級網絡中會碰到這種結構,比如說MobileNet.

2.1.1 常規卷積操作

對於一張 5x5 像素,三通道彩色輸入圖片(5x5x3),經過3x3卷積核的層,假設輸出通道數量為4,則卷積核的shape為 3x3x3x4,最終輸出4個Feature Map.

  • 如果為padding=same,則特征圖尺寸為5x5
  • padding=valid,特征圖尺寸3x3

2.1.2 DepthWise Convolution

不同於常規卷積操作:

DepthWise Convolution的一個卷積核只負責一個通道,一個通道只能被一個卷積核卷積.

同樣對於這張5x5像素,三通道的彩色輸入圖片,DepthWise Convolution首先經過第一次卷積運算,不同於上面的常規卷積

DepthWise Convolution完全在二維平面進行,卷積核數量與上一層必須一致.(上一層通道與卷積核個數一致)

所以一個三通道的圖像經過運算后生成了3個Feature Map.

但是這就存在一個缺點,首先:

  1. DepthWise Convolution 完成后 Feature Map 數量和輸入層的通道數量相同,無法拓展Feature Map.
  2. 這種運算對輸入層的每個通道獨立進行卷積運算,沒有有效利用在相同空間位置上的 feature 信息.

因此采用 PointWise Convolution 來將這些Feature Map重新組合生成新的 Feature Map.

2.1.3 PointWise Convolution

PointWise Convolution 的運算與常規卷積運算非常相似,它的卷積核的尺寸為 1x1xM,其中M為上一層的通道數量:

這里PointWise Convolution 運算會將上一步的map在深度方向上進行加權組合,生成新的Feature map,有多少個卷積核就有多少個輸出Feature Maps.

有意思的是其實之前許多網絡,例如Inception v3用PointWise Convolution來做維度縮減來降低參數,這里用來聯系以及拓展Feature Maps。而DepthWise Convolution也並不是新出現的,它可以看做是分組卷積的特例,早在AlexNet就出現過.

3. 網絡結構

Xception作為Inception v3的改進,主要是在Inception v3的基礎上引入了depthwise separable convolution,在基本不增加網絡復雜度的前提下提高了模型的效果.

疑問

  1. 有些人會好奇為什么引入depthwise separable convolution沒有大大降低網絡的復雜度?

原因在於作者加寬了網絡,使得參數數量和Inception v3差不多,然后在這前提下比較性能.因此Xception目的不在於模型壓縮,而是提高性能.

4. 核心代碼

from keras.models import Model
from keras.layers import Dense, Input, BatchNormalization, Activation, add
from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalAveragePooling2D
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.utils import plot_model


def Xception():
    input_shape = _obtain_input_shape(None, default_size=299, min_size=71, data_format='channels_last', require_flatten=True)
    img_input = Input(shape=input_shape)

    # Block 1
    x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False)(img_input)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    # Block 2
    x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    x = add([x, residual])

    residual = Conv2D(256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    # Block 3
    x = Activation('relu')(x)
    x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    x = add([x, residual])

    residual = Conv2D(728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    # Block 4
    x = Activation('relu')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    x = add([x, residual])

    # Block 5-12
    for i in range(8):
        residual = x
        x = Activation('relu')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
        x = BatchNormalization()(x)

        x = add([x, residual])

    residual = Conv2D(1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
    residual = BatchNormalization()(residual)

    # Block 13
    x = Activation('relu')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
    x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)

    # Block 13 Pool
    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
    x = add([x, residual])

    # Block 14
    x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Block 14 part2
    x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # 全鏈接層
    x = GlobalAveragePooling2D()(x)
    x = Dense(1000, activation='softmax')(x)

    return Model(inputs=img_input, outputs=x, name='xception')


if __name__ == '__main__':
    model = Xception()
    model.summary()
    plot_model(model, show_shapes=True)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM