DARTS
2019-ICLR-DARTS Differentiable Architecture Search
- Source: ChenBong, 博客園 (cnblogs)
- Institute:CMU、Google
- Author:Hanxiao Liu、Karen Simonyan、Yiming Yang
- GitHub:2.8k stars
- https://github.com/quark0/darts
- https://github.com/khanrc/pt.darts
- Citation:557
Questions
&& When updating the architecture parameters α, is an exponential moving average (EMA) used?
No.
&& For an op's padding, is the input padded before the convolution, or convolved first and then padded?
Padding first, then the convolution.
&& What does the FactorizedReduce() function do?
It halves the spatial size of the feature map.
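A minimal sketch of how FactorizedReduce is usually implemented in the repos linked above (the exact signature here is an illustrative assumption): two stride-2 1x1 convolutions, one of them applied to the input shifted by one pixel, concatenated along the channel dimension so the height/width are halved while the output channel count is controlled.

```python
import torch
import torch.nn as nn

class FactorizedReduce(nn.Module):
    """Halve the spatial resolution (H, W) while mapping c_in -> c_out channels."""
    def __init__(self, c_in, c_out, affine=True):
        super().__init__()
        self.relu = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv2d(c_in, c_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(c_in, c_out - c_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(c_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        # conv2 sees the input shifted by one pixel, so the two stride-2 convs
        # together cover all spatial positions before concatenation
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        return self.bn(out)

# e.g. a [32, 64, 32, 32] feature map becomes [32, 32, 16, 16]
# with FactorizedReduce(64, 32), as in the Cell_3 preproc0 trace below
```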
&& Which node in the Reduction Cell has stride=2? What exactly are the inputs/outputs of the nodes in a Reduction Cell?
It is not a node of the reduction cell itself that has stride=2: the stride-2 sits in the preprocessing of the node inputs (in the trace below, a reduction cell node's "In" is 32x32 while its "pre_Out" is 16x16), and the following cell additionally reduces its preproc0 input; see the "Network input/output shapes during the search" and "Discrete network structure" sections for the concrete shapes.
&& What does the size preprocessing for Cell_3's Node_0 (input node 0) look like?
# If cell [k-1] is a reduction cell, the current cell's input size equals cell [k-1]'s
# output size and no longer matches cell [k-2]'s output size,
# so the output of cell [k-2] must be reduced first.
if reduction_p:
    # cell [k-1] is a reduction cell: halve the feature map
    # input node_0: preprocesses the output of cell [k-2]
    self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
else:
    # cell [k-1] is not a reduction cell: plain 1x1 convolution
    # input node_0: preprocesses the output of cell [k-2]
    self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
&& Are α and w updated per batch or per epoch?
Per batch.
&& Which optimizer is used to update α, and with what hyperparameters?
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999), weight_decay=1.0E-3)
i.e. Adam over the architecture parameters with betas = (0.5, 0.999) and weight decay 1e-3.
&& How are the weights actually updated? Only one step?
With the first-order approximation, w is updated once per iteration.
With the second-order approximation, w is still updated once per iteration, but the α gradient additionally uses a virtual one-step-ahead copy \(w^{\prime}=w-\xi \nabla_{w} \mathcal{L}_{train}(w, \alpha)\) (see Eqs. (6) and (7)).
&& α is updated on the val set and w on the train set: how is the data split?
In this implementation the val set is simply the CIFAR-10 test set.
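A hedged sketch of the corresponding data setup (paths and batch size are illustrative): w is trained on the CIFAR-10 training set, while α is updated on a separate stream, here the CIFAR-10 test set as in the implementation referenced above.

```python
import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor()])

# training set -> updates the network weights w
train_data = datasets.CIFAR10("./data", train=True, download=True, transform=transform)
# test set reused as the "val" set -> updates the architecture parameters alpha
valid_data = datasets.CIFAR10("./data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=64, shuffle=True)
```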
Introduction
Motivation
Previous NAS methods:
- Prohibitive computational cost: 2000-3000 GPU days
- A discrete search space, which means a huge number of architectures have to be evaluated
Contribution
- A differentiable architecture search method based on gradient descent
- Applicable to both CNNs and RNNs
- Achieves SOTA results on CIFAR-10 and PTB
- Efficiency: 2000 GPU days vs. 4 GPU days
- Transferability: the architecture searched on CIFAR-10 transfers to ImageNet, and the architecture searched on PTB transfers to WikiText-2
Method
Search space
DARTS searches for a cell structure that is used as the building block of the final network.
The searched cell can be stacked to form a CNN, or recursively connected to form an RNN.
A cell is a directed acyclic graph (DAG) containing N nodes.
Notes on Figure 1:
Figure 1 shows a single cell; every node is connected to all nodes with smaller indices.
Node i holds a feature map \(x^{(i)}\); the differently colored arrows between nodes are different ops, each with its own weights.
The ops between two nodes are drawn from the op set O, so the number of candidate ops between two nodes is |O|.
Every op between nodes i and j has a corresponding architecture parameter in \(α^{(i, j)}\) (which can be read as that op's strength/weight); \(α^{(i,j)}\) is an |O|-dimensional vector.
\(x^{(j)}=\sum_{i<j} o^{(i, j)}\left(x^{(i)}\right) \qquad (1)\)
Notes on Eq. (1):
- \(x^{(i)}\) is the feature map of node i
- \(o^{(i, j)}\) is the operation applied on the edge from node i to node j (during search, the mixed operation defined below)
- \(o^{(i, j)}\left(x^{(i)}\right)\) is the new feature map obtained by applying that operation to \(x^{(i)}\)
- Node j applies \(o^{(i, j)}\left(x^{(i)}\right)\) to every predecessor node i < j and sums the results to obtain its own feature map
\(\bar{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} o(x) \qquad (2)\)
Notes on Eq. (2):
- The vector \(α^{(i, j)}\) has dimension |O|
- Applying softmax to \(α^{(i, j)}\) gives the normalized architecture weights \(\hat{α}_o^{(i, j)}=\frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)}\)
- Every op in O is applied to x, multiplied by its weight \(\hat{α}_o^{(i, j)}\), and the results are summed to give \(\bar{o}^{(i, j)}(x)\)
- \(\bar{o}^{(i, j)}(·)\) is called the mixed operation (mix op)
&& Between two nodes, what if the differently colored ops produce output feature maps of different sizes?
The two nodes share the same input; since the op types differ, their raw outputs could have different sizes, so the code uses padding to keep the output feature map dimensions of all ops identical.
&& Between two nodes, how are the output feature maps of the different ops combined: sum or concat?
The output feature maps of the different ops (same channel count and size) are summed element-wise, so after combining, the channel count and size of the feature map are unchanged.
&& How are output feature maps coming from different nodes combined: sum or concat?
Summed.
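A minimal sketch of Eq. (2) (the mixed op on one edge) and Eq. (1) (a node summing over its predecessors); the candidate op list and tensor shapes below are illustrative assumptions, not the exact code of the repos linked above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Zero(nn.Module):
    """The 'zero' candidate op: outputs all zeros of the same shape."""
    def forward(self, x):
        return x * 0.0

class MixedOp(nn.Module):
    """Eq. (2): softmax-weighted sum of all candidate ops on one edge (i, j)."""
    def __init__(self, ops):
        super().__init__()
        self.ops = nn.ModuleList(ops)  # every op must preserve the output shape

    def forward(self, x, alpha_edge):
        weights = F.softmax(alpha_edge, dim=-1)   # |O| architecture weights
        return sum(w * op(x) for w, op in zip(weights, self.ops))

def node_output(prev_feats, edge_ops, edge_alphas):
    """Eq. (1): node j sums the mixed ops applied to all earlier nodes i < j."""
    return sum(op(x, a) for x, op, a in zip(prev_feats, edge_ops, edge_alphas))

# Toy usage with 3 candidate ops and C = 16 channels
C = 16
candidates = [nn.Identity(),
              nn.Conv2d(C, C, 3, padding=1, bias=False),
              Zero()]
edge = MixedOp(candidates)
alpha = torch.zeros(len(candidates), requires_grad=True)  # one alpha vector per edge
x = torch.randn(2, C, 8, 8)
y = edge(x, alpha)  # same shape as x: [2, 16, 8, 8]
```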
\(o^{(i, j)}=\operatorname{argmax}_{o \in \mathcal{O}} \alpha_{o}^{(i, j)}\)
Notes on this equation:
At the end of the search, among the differently colored edges between two nodes, only the one whose architecture parameter \(α_o^{(i, j)}\) is largest is kept.
Legend for the architecture figures
**CNN cell structure:**
Each triangle stands for the whole group of candidate ops between two nodes in Figure 1, i.e. each triangle computes the mixed operation of Eq. (2): \(\bar{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} o(x)\). At the end of the search, each group of edges in Figure 1 keeps only one op, so each triangle keeps only its single strongest op; furthermore, every node keeps only the 2 strongest of its candidate input ops.
Optimization objective
The goal is to jointly learn the architecture parameters (α) and the network weights (w):
\(\min _{\alpha} \mathcal{L}_{v a l}\left(w^{*}(\alpha), \alpha\right) \qquad (3)\)
s.t. \(\quad w^{*}(\alpha)=\operatorname{argmin}_{w} \mathcal{L}_{\text {train}}(w, \alpha) \qquad (4)\)
Notes on Eqs. (3) and (4):
- \(w^{*}(\alpha)\) is the optimal network weight for a given \(α\); different \(α\) correspond to different optimal weights \(w^{*}(\alpha)\)
- Training procedure:
  - Whenever \(α\) changes, first train the network weights to the corresponding optimum \(w^{*}(\alpha)\) (Eq. (4))
  - Then do gradient descent on \(α\): try different architecture parameters and find the \(α\) that minimizes the validation loss, i.e. the best architecture (Eq. (3))
Algorithm 1 (DARTS, differentiable architecture search):
- Build a mixed operation \(\bar{o}^{(i, j)}(·)\) for every edge, parameterized by \(α^{(i, j)}\)
- While not converged:
  - Update the architecture parameters \(α\) by descending \(\nabla_{\alpha} \mathcal{L}_{val}\left(w-\xi \nabla_{w} \mathcal{L}_{train}(w, \alpha), \alpha\right)\)
  - Update the weights w by descending \(\nabla_{w} \mathcal{L}_{train}(w, \alpha)\)
- Derive the final architecture from the converged \(α\)
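A per-batch sketch of this loop with the first-order approximation (ξ = 0): one Adam step on α using a validation batch, then one SGD step on w using a training batch. `model`, `model.weights()`, `model.alphas()` and the two loaders are assumed helpers, and the optimizer settings follow the values quoted earlier.

```python
import torch
import torch.nn.functional as F

# Assumed: model.weights() -> iterable of w, model.alphas() -> iterable of alpha,
# train_loader / valid_loader yield (x, y) batches.
w_optim = torch.optim.SGD(model.weights(), lr=0.025, momentum=0.9, weight_decay=3e-4)
alpha_optim = torch.optim.Adam(model.alphas(), lr=3e-4,
                               betas=(0.5, 0.999), weight_decay=1e-3)

for (x_trn, y_trn), (x_val, y_val) in zip(train_loader, valid_loader):
    # 1) update alpha on a validation batch
    #    (first-order: xi = 0, so the gradient is simply grad_alpha L_val(w, alpha))
    alpha_optim.zero_grad()
    F.cross_entropy(model(x_val), y_val).backward()
    alpha_optim.step()

    # 2) update w on a training batch
    w_optim.zero_grad()
    F.cross_entropy(model(x_trn), y_trn).backward()
    w_optim.step()
```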
Approximation
Retraining the network weights to convergence after every update of \(α\) would be extremely expensive. Instead, Eq. (3) is simplified: \(w^{*}(\alpha)\) is approximated by taking a single weight update step rather than by training to convergence.
\(\nabla_{\alpha} \mathcal{L}_{v a l}\left(w^{*}(\alpha), \alpha\right) \qquad (5)\)
\((5)\approx \nabla_{\alpha} \mathcal{L}_{v a l}\left(w-\xi \nabla_{w} \mathcal{L}_{t r a i n}(w, \alpha), \alpha\right) \qquad(6)\)
Notes on Eq. (6):
\(\xi\) is a hyperparameter, the learning rate of the single (virtual) weight step. After \(α\) is updated, the formula approximates the converged weights by taking just one weight update step instead of training to convergence. Note that when w is already optimal for the inner problem, i.e. \(\nabla_{w} \mathcal{L}_{train}(w, \alpha)=0\), Eq. (6) degenerates to \(\nabla_{\alpha} \mathcal{L}_{val}(w, \alpha)\).
\((6)=\nabla_{\alpha} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right)-\xi \nabla_{\alpha, w}^{2} \mathcal{L}_{t r a i n}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right) \qquad (7)\)
where \(w^{\prime}=w-\xi \nabla_{w} \mathcal{L}_{train}(w, \alpha)\) denotes the weights after the virtual step.
Notes on Eq. (7):
Applying the chain rule to Eq. (6) yields Eq. (7).
&& The second term of Eq. (7) contains an expensive matrix-vector product \(\xi \nabla_{\alpha, w}^{2} \mathcal{L}_{train}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{val}\left(w^{\prime}, \alpha\right)\); the paper avoids it with a finite-difference approximation, as follows.
Let \(w^{\pm}=w \pm \epsilon \nabla_{w^{\prime}} \mathcal{L}_{val}\left(w^{\prime}, \alpha\right)\); then:
\(\nabla_{\alpha, w}^{2} \mathcal{L}_{t r a i n}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right) \approx \frac{\nabla_{\alpha} \mathcal{L}_{t r a i n}\left(w^{+}, \alpha\right)-\nabla_{\alpha} \mathcal{L}_{t r a i n}\left(w^{-}, \alpha\right)}{2 \epsilon} \qquad (8)\)
Notes on Eq. (8):
&& Evaluating this finite difference requires only two extra forward passes for the weights and two backward passes for \(α\), reducing the complexity from \(O(|\alpha||w|)\) to \(O(|\alpha|+|w|)\).
In theory \(\epsilon\) should be sufficiently small; empirically \(\epsilon=0.01 /\left\|\nabla_{w^{\prime}} \mathcal{L}_{val}\left(w^{\prime}, \alpha\right)\right\|_{2}\) is accurate enough.
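A hedged sketch of the full second-order α gradient of Eqs. (6)-(8): take the virtual step \(w^{\prime}=w-\xi \nabla_{w} \mathcal{L}_{train}\), compute \(\nabla_{\alpha} \mathcal{L}_{val}(w^{\prime}, \alpha)\), and subtract ξ times the Hessian-vector product approximated by the finite difference above. The helper names (`model.weights()`, `model.alphas()`, the loss closures) are assumptions.

```python
import torch

def second_order_alpha_grads(model, loss_train_fn, loss_val_fn, xi):
    """Return d L_val / d alpha following Eqs. (6)-(8) (sketch).

    loss_train_fn(model) / loss_val_fn(model) compute the losses on the current
    train / val batch; model.weights() and model.alphas() return the parameter groups.
    """
    w = list(model.weights())
    alphas = list(model.alphas())
    backup = [p.detach().clone() for p in w]

    # virtual step: w' = w - xi * grad_w L_train(w, alpha)   (Eq. 6)
    g_w = torch.autograd.grad(loss_train_fn(model), w)
    with torch.no_grad():
        for p, g in zip(w, g_w):
            p.sub_(xi * g)

    # gradients of L_val evaluated at (w', alpha)
    grads = torch.autograd.grad(loss_val_fn(model), alphas + w)
    g_alpha, g_wprime = grads[:len(alphas)], grads[len(alphas):]

    # finite-difference Hessian-vector product   (Eq. 8)
    eps = 0.01 / torch.sqrt(sum((g * g).sum() for g in g_wprime))

    def alpha_grads_at(sign):
        with torch.no_grad():
            for p, b, g in zip(w, backup, g_wprime):
                p.copy_(b + sign * eps * g)       # w +/- eps * grad_w' L_val
        return torch.autograd.grad(loss_train_fn(model), alphas)

    g_plus, g_minus = alpha_grads_at(+1.0), alpha_grads_at(-1.0)
    hvp = [(gp - gm) / (2.0 * eps) for gp, gm in zip(g_plus, g_minus)]

    # restore the original weights
    with torch.no_grad():
        for p, b in zip(w, backup):
            p.copy_(b)

    # Eq. (7): grad_alpha L_val(w', alpha) - xi * (Hessian-vector product)
    # The returned list can be copied into each alpha.grad before alpha_optim.step().
    return [ga - xi * h for ga, h in zip(g_alpha, hvp)]
```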
Choice of \(\xi\):
- When \(\xi = 0\), the second-order term in Eq. (7) vanishes, so the gradient reduces to \(\nabla_{\alpha} \mathcal{L}_{val}(w, \alpha)\), i.e. the current w is simply treated as \(w^{*}(\alpha)\)
- \(\xi = 0\) is called the first-order approximation
- \(\xi > 0\) gives the second-order approximation; in this case a simple heuristic is to set \(\xi\) equal to the learning rate used for the network weights w
Experiment on \(\xi\):
Consider the simple toy losses:
- \(\mathcal{L}_{\text {val}}(w, \alpha)=\alpha w-2 \alpha+1\)
- \(\mathcal{L}_{\text {train}}(w, \alpha)=w^2-2\alpha w+ \alpha^2\)
Optimizing with different values of \(\xi\) gives different trajectories (see the corresponding figure in the paper).
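Since the figure is not reproduced here, a small numeric sketch of the same toy problem, starting from (α, w) = (2, −2) as in the paper: \(\mathcal{L}_{train}=(w-\alpha)^2\) gives \(w^{*}(\alpha)=\alpha\), so the true outer objective is \((\alpha-1)^2\) with optimum (α, w) = (1, 1). The step sizes below are illustrative.

```python
# Toy bi-level problem from above, starting from (alpha, w) = (2, -2):
#   L_val(w, a)   = a*w - 2a + 1
#   L_train(w, a) = w^2 - 2aw + a^2 = (w - a)^2
# Analytic derivatives used below:
#   dL_val/da = w - 2,  dL_val/dw = a,
#   dL_train/dw = 2w - 2a,  d2L_train/(da dw) = -2
def run(xi, steps=200, lr_a=0.1, lr_w=0.5):
    a, w = 2.0, -2.0
    for _ in range(steps):
        # alpha step, Eq. (6)/(7)
        w_virt = w - xi * (2 * w - 2 * a)      # w' = w - xi * dL_train/dw
        grad_a = (w_virt - 2.0)                # dL_val/da evaluated at (w', a)
        grad_a -= xi * (-2.0) * a              # - xi * d2L_train/(da dw) * dL_val/dw'
        a -= lr_a * grad_a
        # weight step on L_train
        w -= lr_w * (2 * w - 2 * a)
    return round(a, 3), round(w, 3)

print("first-order  xi=0.0:", run(0.0))   # with these step sizes, drifts to the suboptimal point (2, 2)
print("second-order xi=0.5:", run(0.5))   # xi = lr_w steers the iterates toward the optimum (1, 1)
```

The comparison illustrates the role of the second-order term of Eq. (7): with ξ set to the weight learning rate, the α gradient matches the true bi-level gradient much more closely than with ξ = 0.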
Continuous structure => discrete structure
To construct the discrete cell (edges no longer carry architecture parameters, or equivalently all remaining ones equal 1), each node keeps only the k edges with the strongest op: k = 2 for CNNs and k = 1 for RNNs.
That is, in the figures below, every node of a CNN cell has k = 2 inputs, and every node of an RNN cell has k = 1 input.
&& How is this implemented in code? (a sketch follows after this block)
&& After stacking, are the multiple cells identical? How is that implemented? (cf. the Experiments notes below: \(α_{normal}\) is shared by all normal cells and \(α_{reduce}\) by all reduction cells)
The op strength is defined as: \(\frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)}\)
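Regarding the "how is this implemented" question above, a hedged sketch of deriving the discrete cell from α: for each edge take the strongest non-zero op, then for each node keep the top-k incoming edges ranked by that op strength (k = 2 for CNNs, k = 1 for RNNs). The α layout and names here are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def derive_cell(alpha, op_names, k=2, zero_name="none"):
    """alpha: dict {(i, j): tensor of shape [|O|]} for every edge i -> j.

    Returns {j: [(i, op_name), ...]} keeping, for each node j, the k incoming
    edges whose best (non-'none') op has the largest softmax weight.
    """
    genotype = {}
    nodes = sorted({j for (_, j) in alpha})
    for j in nodes:
        candidates = []
        for (i, jj), a in alpha.items():
            if jj != j:
                continue
            w = F.softmax(a, dim=-1)
            # ignore the zero op when ranking edges
            best = max((o for o in range(len(op_names)) if op_names[o] != zero_name),
                       key=lambda o: w[o].item())
            candidates.append((w[best].item(), i, op_names[best]))
        top = sorted(candidates, reverse=True)[:k]
        genotype[j] = [(i, name) for _, i, name in top]
    return genotype

# Toy usage: 3 candidate ops on the edges into node 2 (from nodes 0 and 1)
ops = ["none", "sepconv3x3", "skipconnect"]
alpha = {(0, 2): torch.tensor([0.1, 1.2, 0.3]),
         (1, 2): torch.tensor([0.0, 0.4, 0.9])}
print(derive_cell(alpha, ops, k=2))
# -> {2: [(0, 'sepconv3x3'), (1, 'skipconnect')]}  (ordered by strength)
```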
Experiments
The candidate ops in the op set O are:
- 3 × 3 and 5 × 5 separable convolutions,
- 3 × 3 and 5 × 5 dilated separable convolutions,
- 3 × 3 max pooling,
- 3 × 3 average pooling,
- identity (i.e. skip connection)
- zero.
All ops:
- use stride = 1 (where applicable)
- outputs of different ops (whose resolutions could otherwise differ) are padded so that they keep the same spatial resolution
Implementation details:
- Convolution ops use the ReLU-Conv-BN order
- Every separable convolution is applied twice
- A CNN cell contains N = 7 nodes; the output node is defined as the concatenation of all intermediate nodes (feature maps)
&& How is the concat handled if dimensions differ?
All intermediate nodes share the same channel count and spatial size, so they are simply concatenated along the channel dimension (in the trace below, 4 intermediate nodes x 16 channels give a 64-channel cell output).
Each cell has 2 input nodes and 1 output node.
- The 2 input nodes of cell k are the output nodes of cell k-2 and cell k-1, respectively
- The 2 cells located at 1/3 and 2/3 of the network depth are reduction cells, i.e. the ops adjacent to their input nodes use stride = 2
- There are therefore 2 kinds of cells, normal cells and reduction cells, with separate architecture parameters \(α_{normal}\) and \(α_{reduce}\)
- \(α_{normal}\) is shared by all normal cells and \(α_{reduce}\) by all reduction cells
- To determine the final architecture, DARTS is run 4 times with different random seeds, each resulting cell is trained from scratch for a small number of epochs (100 for CIFAR-10, 300 for PTB), and the best cell is picked according to this short training
- Because the cell will be stacked many times, running the search several times is worthwhile, and the outcome can be initialization-sensitive (see Figures 2 and 4)
Architecture evaluation
To evaluate a searched architecture, its weights are randomly re-initialized (the weights learned during the search are discarded), the network is trained from scratch, and its performance on the test set is reported.
Results analysis
Figure 3 notes:
- DARTS achieves results comparable to the SOTA while using about three orders of magnitude less compute
- (i.e. 1.5 or 4 GPU days vs. 2000 GPU days for NASNet and 3150 GPU days for AmoebaNet)
- The longer search time comes from repeating the cell search 4 times for cell selection; this repetition is not particularly important for CNN cells, which are less initialization-sensitive, but matters more for RNN cells, which are more initialization-sensitive
Table 1 notes:
- Table 1 shows that random search is also competitive, which suggests that the search space itself is well designed.
Table 3 notes:
- Cells searched on CIFAR-10 can indeed be transferred to ImageNet.
Table 4 notes:
- Table 4 shows that transferability between PTB and WT2 is weaker (compared with CIFAR-10 to ImageNet), because the source dataset used for the search (PTB) is small
- Searching directly on the dataset of interest avoids the transfer problem
Network input/output shapes during the search
CNN:==================================================================
CNN In: torch.Size([32, 3, 32, 32])
CNN stem In : torch.Size([32, 3, 32, 32])
CNN stem Out: torch.Size([32, 48, 32, 32]), torch.Size([32, 48, 32, 32])
Cell_0:========================
Cell_0 In: torch.Size([32, 48, 32, 32]) torch.Size([32, 48, 32, 32])
Preproc0_in: torch.Size([32, 48, 32, 32]), Preproc1_in: torch.Size([32, 48, 32, 32])
Preproc0_out: torch.Size([32, 16, 32, 32]), Preproc1_out: torch.Size([32, 16, 32, 32])
Node_0 In: 1 x torch.Size([32, 16, 32, 32])
Node_0 Out: 1 x torch.Size([32, 16, 32, 32])
Node_1 In: 1 x torch.Size([32, 16, 32, 32])
Node_1 Out: 1 x torch.Size([32, 16, 32, 32])
Node_2 In:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node pre_Out:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node_2 Out: 1 x torch.Size([32, 16, 32, 32])
Node_3 In:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node pre_Out:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node_3 Out: 1 x torch.Size([32, 16, 32, 32])
Node_4 In:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node pre_Out:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node_4 Out: 1 x torch.Size([32, 16, 32, 32])
Node_5 In:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node pre_Out:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node_5 Out: 1 x torch.Size([32, 16, 32, 32])
Cell_0 Out: torch.Size([32, 64, 32, 32])
Cell_1:========================
Cell_1 In: torch.Size([32, 48, 32, 32]) torch.Size([32, 64, 32, 32])
Preproc0_in: torch.Size([32, 48, 32, 32]), Preproc1_in: torch.Size([32, 64, 32, 32])
Preproc0_out: torch.Size([32, 16, 32, 32]), Preproc1_out: torch.Size([32, 16, 32, 32])
Node_0 In: 1 x torch.Size([32, 16, 32, 32])
Node_0 Out: 1 x torch.Size([32, 16, 32, 32])
Node_1 In: 1 x torch.Size([32, 16, 32, 32])
Node_1 Out: 1 x torch.Size([32, 16, 32, 32])
Node_2 In:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node pre_Out:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node_2 Out: 1 x torch.Size([32, 16, 32, 32])
Node_3 In:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node pre_Out:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node_3 Out: 1 x torch.Size([32, 16, 32, 32])
Node_4 In:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node pre_Out:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node_4 Out: 1 x torch.Size([32, 16, 32, 32])
Node_5 In:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node pre_Out:
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
torch.Size([32, 16, 32, 32])
Node_5 Out: 1 x torch.Size([32, 16, 32, 32])
Cell_1 Out: torch.Size([32, 64, 32, 32])
Cell_2:========================
Cell_2 In: torch.Size([32, 64, 32, 32]) torch.Size([32, 64, 32, 32])
Preproc0_in: torch.Size([32, 64, 32, 32]), Preproc1_in: torch.Size([32, 64, 32, 32])
Preproc0_out: torch.Size([32, 32, 32, 32]), Preproc1_out: torch.Size([32, 32, 32, 32])
Node_0 In: 1 x torch.Size([32, 32, 32, 32])
Node_0 Out: 1 x torch.Size([32, 32, 32, 32])
Node_1 In: 1 x torch.Size([32, 32, 32, 32])
Node_1 Out: 1 x torch.Size([32, 32, 32, 32])
Node_2 In:
torch.Size([32, 32, 32, 32])
torch.Size([32, 32, 32, 32])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_2 Out: 1 x torch.Size([32, 32, 16, 16])
Node_3 In:
torch.Size([32, 32, 32, 32])
torch.Size([32, 32, 32, 32])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_3 Out: 1 x torch.Size([32, 32, 16, 16])
Node_4 In:
torch.Size([32, 32, 32, 32])
torch.Size([32, 32, 32, 32])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_4 Out: 1 x torch.Size([32, 32, 16, 16])
Node_5 In:
torch.Size([32, 32, 32, 32])
torch.Size([32, 32, 32, 32])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_5 Out: 1 x torch.Size([32, 32, 16, 16])
Cell_2 Out: torch.Size([32, 128, 16, 16])
Cell_3:========================
Cell_3 In: torch.Size([32, 64, 32, 32]) torch.Size([32, 128, 16, 16])
Preproc0_in: torch.Size([32, 64, 32, 32]), Preproc1_in: torch.Size([32, 128, 16, 16])
Preproc0_out: torch.Size([32, 32, 16, 16]), Preproc1_out: torch.Size([32, 32, 16, 16])
Node_0 In: 1 x torch.Size([32, 32, 16, 16])
Node_0 Out: 1 x torch.Size([32, 32, 16, 16])
Node_1 In: 1 x torch.Size([32, 32, 16, 16])
Node_1 Out: 1 x torch.Size([32, 32, 16, 16])
Node_2 In:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_2 Out: 1 x torch.Size([32, 32, 16, 16])
Node_3 In:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_3 Out: 1 x torch.Size([32, 32, 16, 16])
Node_4 In:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_4 Out: 1 x torch.Size([32, 32, 16, 16])
Node_5 In:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_5 Out: 1 x torch.Size([32, 32, 16, 16])
Cell_3 Out: torch.Size([32, 128, 16, 16])
Cell_4:========================
Cell_4 In: torch.Size([32, 128, 16, 16]) torch.Size([32, 128, 16, 16])
Preproc0_in: torch.Size([32, 128, 16, 16]), Preproc1_in: torch.Size([32, 128, 16, 16])
Preproc0_out: torch.Size([32, 32, 16, 16]), Preproc1_out: torch.Size([32, 32, 16, 16])
Node_0 In: 1 x torch.Size([32, 32, 16, 16])
Node_0 Out: 1 x torch.Size([32, 32, 16, 16])
Node_1 In: 1 x torch.Size([32, 32, 16, 16])
Node_1 Out: 1 x torch.Size([32, 32, 16, 16])
Node_2 In:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_2 Out: 1 x torch.Size([32, 32, 16, 16])
Node_3 In:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_3 Out: 1 x torch.Size([32, 32, 16, 16])
Node_4 In:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_4 Out: 1 x torch.Size([32, 32, 16, 16])
Node_5 In:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node pre_Out:
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
torch.Size([32, 32, 16, 16])
Node_5 Out: 1 x torch.Size([32, 32, 16, 16])
Cell_4 Out: torch.Size([32, 128, 16, 16])
Cell_5:========================
Cell_5 In: torch.Size([32, 128, 16, 16]) torch.Size([32, 128, 16, 16])
Preproc0_in: torch.Size([32, 128, 16, 16]), Preproc1_in: torch.Size([32, 128, 16, 16])
Preproc0_out: torch.Size([32, 64, 16, 16]), Preproc1_out: torch.Size([32, 64, 16, 16])
Node_0 In: 1 x torch.Size([32, 64, 16, 16])
Node_0 Out: 1 x torch.Size([32, 64, 16, 16])
Node_1 In: 1 x torch.Size([32, 64, 16, 16])
Node_1 Out: 1 x torch.Size([32, 64, 16, 16])
Node_2 In:
torch.Size([32, 64, 16, 16])
torch.Size([32, 64, 16, 16])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_2 Out: 1 x torch.Size([32, 64, 8, 8])
Node_3 In:
torch.Size([32, 64, 16, 16])
torch.Size([32, 64, 16, 16])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_3 Out: 1 x torch.Size([32, 64, 8, 8])
Node_4 In:
torch.Size([32, 64, 16, 16])
torch.Size([32, 64, 16, 16])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_4 Out: 1 x torch.Size([32, 64, 8, 8])
Node_5 In:
torch.Size([32, 64, 16, 16])
torch.Size([32, 64, 16, 16])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_5 Out: 1 x torch.Size([32, 64, 8, 8])
Cell_5 Out: torch.Size([32, 256, 8, 8])
Cell_6:========================
Cell_6 In: torch.Size([32, 128, 16, 16]) torch.Size([32, 256, 8, 8])
Preproc0_in: torch.Size([32, 128, 16, 16]), Preproc1_in: torch.Size([32, 256, 8, 8])
Preproc0_out: torch.Size([32, 64, 8, 8]), Preproc1_out: torch.Size([32, 64, 8, 8])
Node_0 In: 1 x torch.Size([32, 64, 8, 8])
Node_0 Out: 1 x torch.Size([32, 64, 8, 8])
Node_1 In: 1 x torch.Size([32, 64, 8, 8])
Node_1 Out: 1 x torch.Size([32, 64, 8, 8])
Node_2 In:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_2 Out: 1 x torch.Size([32, 64, 8, 8])
Node_3 In:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_3 Out: 1 x torch.Size([32, 64, 8, 8])
Node_4 In:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_4 Out: 1 x torch.Size([32, 64, 8, 8])
Node_5 In:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_5 Out: 1 x torch.Size([32, 64, 8, 8])
Cell_6 Out: torch.Size([32, 256, 8, 8])
Cell_7:========================
Cell_7 In: torch.Size([32, 256, 8, 8]) torch.Size([32, 256, 8, 8])
Preproc0_in: torch.Size([32, 256, 8, 8]), Preproc1_in: torch.Size([32, 256, 8, 8])
Preproc0_out: torch.Size([32, 64, 8, 8]), Preproc1_out: torch.Size([32, 64, 8, 8])
Node_0 In: 1 x torch.Size([32, 64, 8, 8])
Node_0 Out: 1 x torch.Size([32, 64, 8, 8])
Node_1 In: 1 x torch.Size([32, 64, 8, 8])
Node_1 Out: 1 x torch.Size([32, 64, 8, 8])
Node_2 In:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_2 Out: 1 x torch.Size([32, 64, 8, 8])
Node_3 In:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_3 Out: 1 x torch.Size([32, 64, 8, 8])
Node_4 In:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_4 Out: 1 x torch.Size([32, 64, 8, 8])
Node_5 In:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node pre_Out:
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
torch.Size([32, 64, 8, 8])
Node_5 Out: 1 x torch.Size([32, 64, 8, 8])
Cell_7 Out: torch.Size([32, 256, 8, 8])
CNN Out: torch.Size([32, 10])
Discrete network structure
Each node keeps the 2 ops with the largest architecture parameters, which yields the discrete structure below. In the JSON, `normal_n{j}_p{i}` is the op chosen on the edge from input i to node j, and each `*_switch` entry lists the 2 input edges that are kept.
// epoch_49.json
{
"normal_n2_p0": "sepconv3x3",
"normal_n2_p1": "sepconv3x3",
"normal_n2_switch": [
"normal_n2_p0",
"normal_n2_p1"
],
"normal_n3_p0": "skipconnect",
"normal_n3_p1": "sepconv3x3",
"normal_n3_p2": [],
"normal_n3_switch": [
"normal_n3_p0",
"normal_n3_p1"
],
"normal_n4_p0": "sepconv3x3",
"normal_n4_p1": "skipconnect",
"normal_n4_p2": [],
"normal_n4_p3": [],
"normal_n4_switch": [
"normal_n4_p0",
"normal_n4_p1"
],
"normal_n5_p0": "skipconnect",
"normal_n5_p1": "skipconnect",
"normal_n5_p2": [],
"normal_n5_p3": [],
"normal_n5_p4": [],
"normal_n5_switch": [
"normal_n5_p0",
"normal_n5_p1"
],
"reduce_n2_p0": "maxpool",
"reduce_n2_p1": "avgpool",
"reduce_n2_switch": [
"reduce_n2_p0",
"reduce_n2_p1"
],
"reduce_n3_p0": "maxpool",
"reduce_n3_p1": [],
"reduce_n3_p2": "skipconnect",
"reduce_n3_switch": [
"reduce_n3_p0",
"reduce_n3_p2"
],
"reduce_n4_p0": [],
"reduce_n4_p1": [],
"reduce_n4_p2": "skipconnect",
"reduce_n4_p3": "skipconnect",
"reduce_n4_switch": [
"reduce_n4_p2",
"reduce_n4_p3"
],
"reduce_n5_p0": [],
"reduce_n5_p1": "avgpool",
"reduce_n5_p2": "skipconnect",
"reduce_n5_p3": [],
"reduce_n5_p4": [],
"reduce_n5_switch": [
"reduce_n5_p1",
"reduce_n5_p2"
]
}
Conclusion
- Proposed DARTS, a simple and efficient architecture search algorithm for CNNs and RNNs that achieves SOTA results
- Improves search efficiency over previous methods by several orders of magnitude
Future improvements:
- The gap between the continuous architecture encoding and the derived discrete architecture
- Weight-sharing-based methods?
Summary