本文介紹針對一篇移動端自動設計網絡的文章《MnasNet:Platform-Aware Neural Architecture Search for Mobile》,由Google提出,很多后續工作都是基於這個工作改進的,因此很有必要學習了解。
Related work
MnasNet的目的很簡單就是設計出表現又好,效率又高的網絡。在介紹之前簡單回顧一下現有的一些提高網絡效率的方法:
- quantization:就是把模型的權重用更低精度表示,例如之前使用float32來存儲權重,那么我們可以試着用8位來存,更極致的思路是0,1來存,這就是Binary Network,也有一些工作研究這個,本文不做細究。
- pruning:就是把模型中不重要的參數刪掉。常用的一種剪枝方法是對通道數進行剪枝,因為這種方法實現起來方便,得到的模型結構也是規則的,計算起來也方便。
- 人工設計模塊
- ShuffleNet
上圖(a)就是加入Depthwise的ResNet bottleneck結構,而(b)和(c)是加入Group convolution和Channel Shuffle的ShuffleNet的結構。 - MobileNet:引入Depthwise Separable Convolution (DWConv)
- MobileNetv2:在DWConv基礎上引入inverted residuals and linear bottlenecks
- SqueezeNet
卷積模塊設計思路如下圖示,首先使用1x1卷積對輸入特征圖做壓縮,所以叫做Squeeze層;壓縮之后需要經過Expand層還原,這里會對壓縮后的特征做兩路還原,一路用1x1卷積,另一路用3x3卷積,最后對兩路的結果做concat。
看下圖可能會更加有助於理解:
- CondenseNet: 參考文章CondenseNet算法筆記
- ShuffleNet
MnasNet算法介紹
優化目標
之前的NAS算法(如DARTS,ENAS)考慮更多的是模型最終結果是否是SOTA,MnasNet則是希望搜索出又小又有效的網絡結構,因此將多個元素作為優化指標,包括准確率,在真實移動設備上的延遲等,最終定義的優化函數如下:
上式中個符號含義如下:
- \(m\)表示模型(model)
- \(ACC(m)\)表示在特定任務上的結果(如准確率)
- \(LAT(m)\)表示在設備上測得的實際計算延遲時間
- \(T\)表示目標延遲時間(target latency)
- \(w\)表示不同場景下對latency的控制因子。當實測延遲時間\(LAT(m)\)小於目標延遲時間\(T\)時,\(w=α\);反之\(w=β\)
上面式子其實表示為帕累托最優,因為一般而言延遲越長,代表模型越大,即參數越大,相應地模型結果也會越好;反之延遲越小,模型表現也會有略微下降。
文中提到latency單位提升會帶來5%的acc提升。也就是說假如模型A最終延遲為t,准確率為a;模型B延遲為2t,那么它的准確率應該是a(1+5%)。但是這兩個模型的reward應該是相等地,套用上面的公式有
求解得到\(\alpha=\beta=-0.7\)
搜索空間
之前的NAS算法都是搜索出一個比較好的cell,然后重復堆疊若干個cell得到最終的網絡,這種方式很明顯限制了網絡的多樣性。MnasNet做了一些改進可以讓每一層不一樣,具體思路是將模型划分成若干個block,每個block可以由不同數量的layer組成,每個layer則由不同的operation來表示,
Net
|__block
|__layer
|___operations
示意圖如下:
可以看到搜索空間包含如下:
- 標准卷積,深度可分離卷積(DWConv), MBConv(即上面提到的MobileNetV2的卷積模塊)
- 卷積核大小:3, 5, 7等
- Squeeze-and-excitation ratio (SE-Ratio): 0, 0.25
- Skip-connection
- 輸出通道數
- 不同block中的layer數量 \(N_i\)
搜索算法
和ENAS一樣使用的是強化學習進行搜索,這里不做細究(其實論文里也沒怎么說)。
實驗
實驗設置
之前的算法都是先在CIFAR10上搜索得到網絡后,再在ImageNet上訓練一個更大的網絡。MnasNet則是直接在ImageNet上搜網絡,但是只是在訓練集上搜了5個epoch。
實驗結果
ImageNet實驗結果
下圖中的結果和預期一樣,延遲越高,結果會稍微好一些。
作者還對比了SE模塊的效果,結果如下,可以看到效果還是不錯的。
有的時候為了適應實際場景需要,我們會對模型的通道數量進行修改,例如都砍掉一半或者增加一倍等,這樣就可以達到模型大小減小或增大的作用了,這個可以由depth multipilier
參數表示。但是有下面的結果可以看出和MobileNetV2相比,基於MnasNet找到的網絡對於通道數量變化魯棒性更強(左圖),同樣對於輸入數據大小也更加具有魯棒性(右圖)。
消融實驗(Ablation Study)
Soft vs. Hard Latency Constraint
前面介紹過用於控制延遲時間的因子 \(\alpha\)和\(\beta\),實驗對比了兩組參數設置:
- \(\alpha=0,\beta=-1\)
- \(\alpha=-0.07,\beta=-0.07\)。實驗結果如下:
設置的目標延遲時間為75ms,可以看到第二個參數配置能夠覆蓋更加廣的模型結構
多目標優化和搜索空間
這一個實驗探究的是本文提出的多目標優化和搜索空間的有效性,一共設置了三組實驗,其中baseline是NASNet,實驗結果如下:
可以看到多目標優化能夠找到延遲更小的網絡,而Mnas提出的搜索空間對模型表現也有一定提升。
MnasNet結構和Layer多樣性
下圖給出了搜索得到的MnasNet的結構,可以看到每層結構都不太一樣,不像之前的算法是簡單地疊加而成。
最后作者還對比了使用單一操作組成的網絡結果對比,實驗結果如下,可以看到雖然只使用MBConv5(k5x5)最終accuracy最高,但是他的推理延遲也很高,所以綜合來看還是MnasNet-A1表現最好。