【DARTS】2019-ICLR-DARTS: Differentiable Architecture Search - Paper Reading Notes


DARTS

2019-ICLR-DARTS Differentiable Architecture Search

Source: ChenBong (博客園 / cnblogs)


Questions

&& When updating the architecture parameters α, is an exponential moving average (EMA) used?

No.


&& For an op's padding, is it pad-then-convolve or convolve-then-pad?

Pad first, then convolve.


&& What does the FactorizedReduce() function do?

It halves the spatial size of the feature map.
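A minimal sketch of the usual FactorizedReduce layout (the common DARTS-style implementation; class and argument names are assumptions, not copied from this repo): ReLU, then two stride-2 1x1 convolutions applied to the input and to the input shifted by one pixel, concatenated along channels and batch-normalized, which halves H and W.

import torch
import torch.nn as nn

class FactorizedReduce(nn.Module):
    # Halve the spatial size: two stride-2 1x1 convs sample complementary
    # positions, their outputs are concatenated along the channel dimension.
    def __init__(self, c_in, c_out, affine=True):
        super().__init__()
        self.relu = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv2d(c_in, c_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(c_in, c_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(c_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        return self.bn(out)

# e.g. torch.Size([32, 64, 32, 32]) -> torch.Size([32, 32, 16, 16]) with c_out=32,
# matching Preproc0 of Cell_3 in the shape trace below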


&& Which node in the Reduction Cell has stride=2? What exactly are the inputs/outputs of the nodes in a Reduction Cell?

It is not a node inside the reduction cell that has stride=2; the stride=2 sits in the reduction cell's input preprocessing. See the Discretized architecture section below for details.


&& What size preprocessing is applied to Node_0 of Cell_3?

# If the [k-1] cell is a reduction cell, the current cell's input size equals the
# [k-1] cell's output size and no longer matches the [k-2] cell's output size,
# so the [k-2] cell's output needs a reduce preprocessing step.
if reduction_p:
    # [k-1] cell is a reduction cell: halve the spatial size of the feature map
    # input node_0: processes the [k-2] cell's output
    self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
else:
    # [k-1] cell is not a reduction cell: plain 1x1 convolution
    # input node_0: processes the [k-2] cell's output
    self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
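For reference, in the same implementation style the preprocessing of the other input node (preproc1, which handles the [k-1] cell's output) is presumably just a plain 1x1 StdConv, since that input always already matches the current cell's resolution (this line is an assumption inferred from the code above, not quoted from the repo):

# input node_1: processes the [k-1] cell's output, which already has the right resolution
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)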

&& Are α and w updated per batch or per epoch?

Per batch.


&& Which optimizer is used to update α? With what hyperparameters? What exactly does it do?

self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate,
                                   betas=(0.5, 0.999), weight_decay=1.0E-3)

&& How are the weights actually updated in practice? Is only one step taken?

With the first-order approximation, w is updated by one step per α update;

with the second-order approximation, w is also updated by one step per α update, but the α gradient additionally uses a one-step "virtual" weight update \(w^{\prime}=w-\xi \nabla_{w} \mathcal{L}_{train}(w, \alpha)\) (see Eqs. (6)-(8) below).


&& α is updated on the val set and w on the train set; how is the dataset split?

The val set used here is CIFAR-10's test set.


Introduction


Motivation

Previous NAS methods:

  • Very high computational cost: 2000-3000 GPU days
  • A discrete search space, which means a huge number of architectures have to be evaluated

Contribution

  • A differentiable, gradient-descent-based search method
  • Applicable to both CNNs and RNNs
  • Achieves SOTA on the CIFAR-10 and PTB datasets
  • Efficiency: 4 GPU days vs. 2000 GPU days for prior methods
  • Transferability: the architecture searched on CIFAR-10 transfers to ImageNet, and the one searched on PTB transfers to WikiText-2

Method

Search space

Search for a cell structure that serves as the building block of the final network.

The searched cell can be stacked to form a CNN, or recursively connected to form an RNN.

A cell is a directed acyclic graph (DAG) containing N nodes.

image-20200524185550276

Figure 1 notes:

Figure 1 shows a cell structure; every node is connected to the nodes with smaller indices.

Node i holds a feature map (\(x^{(i)}\)); the differently colored arrows between nodes denote different ops, each with its own weights.

The operations between nodes are chosen from the op set O, so the number of candidate ops between two nodes is |O|.

Each op between nodes i and j has a corresponding architecture parameter (an entry of \(α^{(i, j)}\), which can be read as that op's strength/weight); \(α^{(i,j)}\) is an |O|-dimensional vector.

\(x^{(j)}=\sum_{i<j} o^{(i, j)}\left(x^{(i)}\right) \qquad (1)\)

Equation (1) notes:

  • \(x^{(i)}\) is the feature map of node i
  • \(o^{(i, j)}\) is the operation applied on edge (i, j) (during search, the mixed op)
  • \(o^{(i, j)}\left(x^{(i)}\right)\) applies that operation to feature map \(x^{(i)}\), producing a new feature map
  • For every node i with i < j, compute \(o^{(i, j)}\left(x^{(i)}\right)\) and sum the results to obtain node j's feature map

\(\bar{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} o(x) \qquad (2)\)

Equation (2) notes:

  • The vector \(α^{(i,j)}\) has dimension |O|
  • A softmax over \(α^{(i, j)}\) gives the normalized architecture parameters \(\hat{α}^{(i, j)}=\frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)}\)
  • Every op in the op set O is applied to x, each result is multiplied by its normalized parameter \(\hat{α}^{(i, j)}\), and the results are summed, giving \(\bar{o}^{(i, j)}(x)\)
  • The resulting mixed op is denoted \(\bar{o}^{(i, j)}(·)\)
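A minimal sketch of the mixed op of Eq. (2) (the idea only, not any repo's exact code; the candidate module list `candidate_ops` and the shape conventions are assumptions): a softmax over the edge's α gives the mixing weights, and all candidate outputs are summed.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedOp(nn.Module):
    # Weighted sum of all candidate ops on one edge -- Eq. (2)
    def __init__(self, candidate_ops):
        super().__init__()
        self.ops = nn.ModuleList(candidate_ops)  # |O| ops, all producing the same output shape

    def forward(self, x, alpha):
        # alpha: tensor of shape (|O|,), the architecture parameters of this edge
        weights = F.softmax(alpha, dim=-1)
        # every op sees the same input; the outputs are summed with the softmax weights
        return sum(w * op(x) for w, op in zip(weights, self.ops))

A node's feature map (Eq. (1)) is then simply the sum of the mixed-op outputs over all of its incoming edges.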

&& Between two nodes, what if the differently colored ops produce output feature maps of different sizes?

Between two nodes, the input size is the same for every op; since the op types differ, their raw output sizes would differ, so the code uses padding to keep the output feature map dimensions of all ops identical.
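Concretely, the padding of each stride-1 op can be derived from its kernel size and dilation so that the spatial size is preserved; a small sketch of the usual choices (values follow the standard padding formula and are assumptions, not read from this repo):

import torch.nn as nn

# padding = dilation * (kernel_size - 1) // 2 keeps H x W unchanged when stride = 1
conv3x3  = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False)
conv5x5  = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=2, bias=False)
dil3x3   = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=2, dilation=2, bias=False)
maxpool3 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
# all of these map (N, 16, 32, 32) -> (N, 16, 32, 32), so their outputs can be summed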

image-20200807222744639


&& Between two nodes, how are the output feature maps of the differently colored ops combined: sum or concat?

The output feature maps of the different ops (identical channel count and size) are summed element-wise, so after combining the ops' outputs, the channel count and size of the feature map are unchanged.


&& How are output feature maps coming from different nodes combined: sum or concat?

Summed.


\(o^{(i, j)}=\operatorname{argmax}_{o \in \mathcal{O}} \alpha_{o}^{(i, j)}\)

Equation notes:

At the end of the search, of the differently colored edges between two nodes only the one with the largest architecture parameter \(α^{(i, j)}\) is kept.

image-20200524185838715


Architecture diagram legend

**CNN cell structure:**

image-20200808160326838

Each triangle corresponds to one group of operations between two nodes in Figure 1, i.e. each triangle performs the operation of Eq. (2): \(\bar{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} o(x) \qquad (2)\). At the end of search, each group of edges in Figure 1 keeps only one op, so each triangle likewise keeps only its single strongest op; moreover, each node keeps only the 2 strongest of its candidate input connections.

image-20200808160542699

In the end, each triangle (one group of operations between two nodes in Figure 1) keeps only a single op.

CNN structure:


A triangle represents one group of ops between two nodes in Figure 1:

\(\bar{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} o(x) \qquad (2)\)

image-20200808160542699

Optimization objective

The goal is to jointly learn the architecture parameters (α) and the network weights (w):

\(\min _{\alpha} \mathcal{L}_{v a l}\left(w^{*}(\alpha), \alpha\right) \qquad (3)\)

s.t. \(\quad w^{*}(\alpha)=\operatorname{argmin}_{w} \mathcal{L}_{\text {train}}(w, \alpha) \qquad (4)\)

Equations (3)(4) notes:

  • \(w^{*}(\alpha)\) is the optimal network weight for a given \(α\); different \(α\) correspond to different optimal weights \(w^{*}(\alpha)\)
  • Training procedure:
    • Each time \(α\) changes, first train the network weights to the corresponding optimum \(w^{*}(\alpha)\) — Eq. (4)
    • Then apply gradient descent to \(α\), trying different architecture parameters to find the \(α\) that minimizes the validation loss, i.e. the best architecture — Eq. (3)

image-20200524185951756

Algorithm 1 (DARTS — Differentiable Architecture Search) notes:

  • Create a mixed op \(\bar{o}^{(i, j)}(·)\) parameterized by \(α^{(i, j)}\) for each edge
  • While not converged, do:
    • Update the architecture parameters \(α\) by gradient descent on \(\nabla_{\alpha} \mathcal{L}_{val}\left(w-\xi \nabla_{w} \mathcal{L}_{train}(w, \alpha), \alpha\right)\)
    • Update the network weights w by gradient descent on \(\nabla_{w} \mathcal{L}_{train}(w, \alpha)\)
  • Derive the final architecture from the converged \(α\)
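A minimal per-batch sketch of this alternating update, using the first-order approximation (ξ = 0) for brevity; `model`, `mutator`, the data loaders and the SGD hyperparameters are assumptions/placeholders, while the Adam settings mirror the optimizer quoted in the Questions section above:

import torch
import torch.nn as nn

def search_one_epoch(model, mutator, train_loader, valid_loader, arc_learning_rate=3.0e-4):
    # model: the DARTS supernet (its forward depends on both w and alpha);
    # mutator: holds the architecture parameters alpha -- both are assumed objects.
    # In practice the optimizers would be created once, outside this function.
    criterion = nn.CrossEntropyLoss()
    w_optim = torch.optim.SGD(model.parameters(), lr=0.025,
                              momentum=0.9, weight_decay=3e-4)
    alpha_optim = torch.optim.Adam(mutator.parameters(), arc_learning_rate,
                                   betas=(0.5, 0.999), weight_decay=1.0e-3)

    for (trn_x, trn_y), (val_x, val_y) in zip(train_loader, valid_loader):
        # 1) update the architecture parameters alpha on a validation batch
        alpha_optim.zero_grad()
        criterion(model(val_x), val_y).backward()   # first order: no virtual step on w
        alpha_optim.step()

        # 2) update the network weights w on a training batch
        w_optim.zero_grad()
        criterion(model(trn_x), trn_y).backward()
        w_optim.step()

The second-order version replaces step 1 with the gradient of Eqs. (6)-(8) below, i.e. it first takes a virtual step w' = w - ξ∇_w L_train(w, α) before computing the validation gradient.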

Approximation

Retraining the network weights to convergence after every update of \(α\) would be extremely time-consuming. Instead, Eq. (3) is simplified: \(w^{*}(\alpha)\) is approximated by a single weight update step, rather than by training the weights to convergence.

\(\nabla_{\alpha} \mathcal{L}_{v a l}\left(w^{*}(\alpha), \alpha\right) \qquad (5)\)

\((5)\approx \nabla_{\alpha} \mathcal{L}_{v a l}\left(w-\xi \nabla_{w} \mathcal{L}_{t r a i n}(w, \alpha), \alpha\right) \qquad(6)\)

Equation (6) notes:

\(\xi\) is a hyperparameter: the learning rate of the one-step (virtual) weight update. When updating the architecture parameters \(α\), this formula approximates the converged weights by taking only a single weight update step (instead of training to convergence). Note that if w is already a local optimum of the training loss, i.e. \(\nabla_{w} \mathcal{L}_{\text {train}}(w, \alpha)=0\), Eq. (6) reduces to \(\nabla_{\alpha} \mathcal{L}_{v a l}(w, \alpha)\).


\((6)=\nabla_{\alpha} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right)-\xi \nabla_{\alpha, w}^{2} \mathcal{L}_{t r a i n}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right) \qquad (7)\)

where, in Eq. (7), \(w^{\prime}=w-\xi \nabla_{w} \mathcal{L}_{train}(w, \alpha)\).

Equation (7) notes:

Applying the chain rule to Eq. (6) yields Eq. (7).

&& The second term of Eq. (7) contains a computationally expensive matrix-vector product, \(\xi \nabla_{\alpha, w}^{2} \mathcal{L}_{train}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{val}\left(w^{\prime}, \alpha\right)\); the paper reduces its cost with a finite difference approximation, as follows.


Let \(w^{\pm}=w \pm \epsilon \nabla_{w^{\prime}} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right)\); then:

\(\nabla_{\alpha, w}^{2} \mathcal{L}_{t r a i n}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right) \approx \frac{\nabla_{\alpha} \mathcal{L}_{t r a i n}\left(w^{+}, \alpha\right)-\nabla_{\alpha} \mathcal{L}_{t r a i n}\left(w^{-}, \alpha\right)}{2 \epsilon} \qquad (8)\)

Equation (8) notes:

&& Evaluating this finite difference requires only two forward passes for the weights and two backward passes for α, greatly reducing the complexity: from \(O(|\alpha||w|)\) to \(O(|\alpha|+|w|)\).

In theory \(\epsilon\) should be sufficiently small; empirically \(\epsilon=0.01 /\left\|\nabla_{w^{\prime}} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right)\right\|_{2}\) is accurate enough.
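A sketch of how Eqs. (7)-(8) can be evaluated: given the lists of weight and α tensors and the validation gradient `vector` (= \(\nabla_{w^{\prime}} \mathcal{L}_{val}(w^{\prime}, \alpha)\)), the second-order term is obtained from two extra training-loss gradients w.r.t. α at the perturbed weights w±. The helper `train_loss_fn` (recomputes the training loss with the current weights) and the argument layout are assumptions for illustration:

import torch

def hessian_vector_product(alphas, weights, vector, train_loss_fn, eps_base=1e-2):
    # Approximate grad^2_{alpha,w} L_train(w, alpha) . vector by finite differences -- Eq. (8)
    eps = eps_base / torch.cat([v.flatten() for v in vector]).norm()

    with torch.no_grad():                       # w+ = w + eps * vector
        for w, v in zip(weights, vector):
            w.add_(eps * v)
    grads_pos = torch.autograd.grad(train_loss_fn(), alphas)

    with torch.no_grad():                       # w- = w - eps * vector
        for w, v in zip(weights, vector):
            w.sub_(2 * eps * v)
    grads_neg = torch.autograd.grad(train_loss_fn(), alphas)

    with torch.no_grad():                       # restore the original weights
        for w, v in zip(weights, vector):
            w.add_(eps * v)

    return [(gp - gn) / (2 * eps) for gp, gn in zip(grads_pos, grads_neg)]

# Eq. (7): grad_alpha = dL_val/dalpha evaluated at w'  -  xi * hessian_vector_product(...)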


Choice of \(\xi\):

  • When \(\xi = 0\), the second-order term in Eq. (7) vanishes, therefore:
    • \(\xi = 0\): first-order approximation
    • \(\xi > 0\): second-order approximation; in this case a simple strategy is to set \(\xi\) to the learning rate used for the network weights w

Experiment on the choice of \(\xi\):

Using a simple toy loss function:

  • \(\mathcal{L}_{\text {val}}(w, \alpha)=\alpha w-2 \alpha+1\)
  • \(\mathcal{L}_{\text {train}}(w, \alpha)=w^2-2\alpha w+ \alpha^2\)
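As a quick sanity check, this toy bilevel problem has a closed-form solution (just completing the square):

\(\mathcal{L}_{\text {train}}(w, \alpha)=w^{2}-2 \alpha w+\alpha^{2}=(w-\alpha)^{2} \Rightarrow w^{*}(\alpha)=\alpha\)

\(\mathcal{L}_{\text {val}}\left(w^{*}(\alpha), \alpha\right)=\alpha^{2}-2 \alpha+1=(\alpha-1)^{2} \Rightarrow\left(\alpha^{*}, w^{*}\right)=(1,1)\)

so the trajectories below should move toward the point (1, 1) when \(\xi\) is chosen suitably.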

The optimization trajectories for different values of \(\xi\) are shown below:

image-20200524190209578

image-20200524190443350


Continuous architecture => discrete architecture

To construct each node of the discrete cell (i.e. a cell whose edges no longer carry architecture parameters, or equivalently whose parameters are all 1), we keep, for every node, the k incoming edges with the strongest ops: k=2 for CNNs and k=1 for RNNs.

That is, in the figure below every node of the CNN cell has k=2 inputs, and every node of the RNN cell has k=1 input.

&& How is this implemented in code? (see the sketch after the op-strength definition below)

&& After stacking, are the multiple cells identical? How is that implemented? (see the α-sharing bullets and the stacking sketch in the Experiments section)

Op strength is defined as: \(\frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)}\)
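A sketch of how the discrete cell can be derived from α (it only mirrors the rule stated above; the per-edge dict layout, the op-name list and the exclusion of the zero op are assumptions, not this repo's exact data structures):

import torch
import torch.nn.functional as F

def derive_node(edge_alphas, op_names, k=2):
    # edge_alphas: {input_index: alpha tensor of shape (|O|,)} for one node j.
    # Keep the k incoming edges whose strongest op has the largest strength,
    # and on each kept edge keep only that strongest op (the zero/'none' op is
    # assumed to be excluded from the argmax).
    strengths, best_op = {}, {}
    for i, alpha in edge_alphas.items():
        probs = F.softmax(alpha, dim=-1)            # op strengths on edge (i, j)
        strengths[i] = probs.max().item()           # strength of the strongest op
        best_op[i] = op_names[probs.argmax().item()]
    kept = sorted(strengths, key=strengths.get, reverse=True)[:k]
    return [(i, best_op[i]) for i in kept]

# e.g. derive_node({0: a0, 1: a1, 2: a2}, op_names, k=2)
# might return [(0, 'sepconv3x3'), (2, 'skipconnect')] -- exactly the kind of
# entry seen in the exported epoch_49.json at the end of this post.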

image-20200524190637100


Experiments

The op set O contains the following ops:

  • 3 × 3 and 5 × 5 separable convolutions,
  • 3 × 3 and 5 × 5 dilated separable convolutions,
  • 3 × 3 max pooling,
  • 3 × 3 average pooling,
  • identity (skip connection)
  • zero.

All ops:

  • stride = 1 (where applicable)
  • feature maps of the different operations (whose resolutions could otherwise differ) are padded so that they keep the same resolution

We use:

  • For convolutional ops, the order ReLU-Conv-BN
  • Each separable convolution is applied twice
  • The CNN cell contains N=7 nodes; the output node is defined as the concatenation of all intermediate nodes' feature maps

&& How should the concat be handled if the dimensions differ?
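Within one cell all intermediate nodes end up with the same spatial size (only the channel count grows through the concat), which the shape trace below also shows; a tiny illustration using the Cell_0 shapes (random tensors as placeholders):

import torch

# outputs of the 4 intermediate nodes of one cell, shapes as in the Cell_0 trace below
node_outputs = [torch.randn(32, 16, 32, 32) for _ in range(4)]

# output node = channel-wise concatenation of all intermediate node outputs
cell_out = torch.cat(node_outputs, dim=1)
print(cell_out.shape)  # torch.Size([32, 64, 32, 32]) -- matches "Cell_0 Out"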

image-20200524190900372

Each cell has 2 input nodes and 1 output node.

  • The 2 input nodes of cell k are the output nodes of cell k-2 and cell k-1, respectively
  • The 2 cells located at 1/3 and 2/3 of the network depth are reduction cells, i.e. the ops adjacent to the cell's input nodes use stride=2
  • There are therefore two kinds of cells, called Normal cells and Reduction cells, with separate architecture parameters \(α_{normal}, α_{reduce}\)
  • \(α_{normal}\) is shared by all Normal cells and \(α_{reduce}\) by all Reduction cells (a stacking sketch is given after the figure below)

  • To determine the final architecture, DARTS is run 4 times with different random seeds, and each of the 4 resulting architectures is trained from scratch for a small number of epochs (100 epochs for CIFAR-10, 300 epochs for PTB); the best cell is picked according to this short-training performance
  • Since the cell is stacked many times and the search result can be sensitive to initialization, running the search several times is necessary, as shown in Figures 2 and 4 below:

image-20200524191124715
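A structural sketch of the stacking rule in the list above: one set of α per cell type, shared by every cell of that type, while each cell keeps its own operation weights (the tensor shapes, node/op counts and the hypothetical build_cell constructor are assumptions for illustration):

import torch
import torch.nn as nn

n_nodes, n_ops, n_layers = 4, 8, 8

# one set of architecture parameters per cell type, shared across all cells of that type;
# intermediate node i (i = 0..3) has i + 2 incoming edges, each with |O| = n_ops parameters
alpha_normal = nn.ParameterList(
    [nn.Parameter(1e-3 * torch.randn(i + 2, n_ops)) for i in range(n_nodes)])
alpha_reduce = nn.ParameterList(
    [nn.Parameter(1e-3 * torch.randn(i + 2, n_ops)) for i in range(n_nodes)])

for i in range(n_layers):
    # cells at 1/3 and 2/3 of the depth are reduction cells
    reduction = i in (n_layers // 3, 2 * n_layers // 3)
    alphas = alpha_reduce if reduction else alpha_normal
    # cell = build_cell(alphas, reduction=reduction)   # hypothetical constructor:
    # every cell of the same type reads the same alphas, but owns its own op weights

So stacked cells of the same type always encode the same candidate architecture (they share α), even though their convolution weights differ.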


Architecture evaluation

To evaluate the searched architecture, its weights are randomly initialized (the weights learned during search are discarded), it is trained from scratch, and its performance on the test set is reported.


image-20200524191204869


image-20200524191223691


Result analysis

image-20200524191305687

Figure 3 notes:

  • DARTS achieves results comparable to the SOTA while using three orders of magnitude less computation
    • (i.e. 1.5 or 4 GPU days vs. 2000 GPU days for NASNet and 3150 GPU days for AmoebaNet)
  • The longer search time comes from repeating the cell search 4 times before selection; this matters less for CNN cells, which are fairly insensitive to initialization, whereas RNN cells are more initialization-sensitive

image-20200524191326038

Table 1 notes:

  • Table 1 shows that random search is also competitive, which suggests the search space itself is well designed.

image-20200524191354234

Table 3 notes:

  • The cell searched on CIFAR-10 can indeed be transferred to ImageNet.

image-20200524191620480

Table 4 notes:

  • Table 4 shows that the transferability between PTB and WT2 is weaker (compared with that between CIFAR-10 and ImageNet), because the source dataset used for the search (PTB) is small
  • Searching directly on the dataset of interest avoids the transferability problem

Input/output shape changes in the network during search

CNN:==================================================================
CNN In:  torch.Size([32, 3, 32, 32])
CNN stem In : torch.Size([32, 3, 32, 32])
CNN stem Out: torch.Size([32, 48, 32, 32]), torch.Size([32, 48, 32, 32])

Cell_0:========================
Cell_0 In:  torch.Size([32, 48, 32, 32]) torch.Size([32, 48, 32, 32])

Preproc0_in:  torch.Size([32, 48, 32, 32]), Preproc1_in: torch.Size([32, 48, 32, 32])
Preproc0_out: torch.Size([32, 16, 32, 32]), Preproc1_out: torch.Size([32, 16, 32, 32])

  Node_0 In:    1 x torch.Size([32, 16, 32, 32])
  Node_0 Out:   1 x torch.Size([32, 16, 32, 32])

  Node_1 In:    1 x torch.Size([32, 16, 32, 32])
  Node_1 Out:   1 x torch.Size([32, 16, 32, 32])

  Node_2 In:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node pre_Out:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node_2 Out:   1 x torch.Size([32, 16, 32, 32])

  Node_3 In:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node pre_Out:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node_3 Out:   1 x torch.Size([32, 16, 32, 32])

  Node_4 In:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node pre_Out:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node_4 Out:   1 x torch.Size([32, 16, 32, 32])

  Node_5 In:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node pre_Out:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node_5 Out:   1 x torch.Size([32, 16, 32, 32])

Cell_0 Out: torch.Size([32, 64, 32, 32])

Cell_1:========================
Cell_1 In:  torch.Size([32, 48, 32, 32]) torch.Size([32, 64, 32, 32])

Preproc0_in:  torch.Size([32, 48, 32, 32]), Preproc1_in: torch.Size([32, 64, 32, 32])
Preproc0_out: torch.Size([32, 16, 32, 32]), Preproc1_out: torch.Size([32, 16, 32, 32])

  Node_0 In:    1 x torch.Size([32, 16, 32, 32])
  Node_0 Out:   1 x torch.Size([32, 16, 32, 32])

  Node_1 In:    1 x torch.Size([32, 16, 32, 32])
  Node_1 Out:   1 x torch.Size([32, 16, 32, 32])

  Node_2 In:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node pre_Out:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node_2 Out:   1 x torch.Size([32, 16, 32, 32])

  Node_3 In:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node pre_Out:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node_3 Out:   1 x torch.Size([32, 16, 32, 32])

  Node_4 In:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node pre_Out:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node_4 Out:   1 x torch.Size([32, 16, 32, 32])

  Node_5 In:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node pre_Out:
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
     torch.Size([32, 16, 32, 32])
  Node_5 Out:   1 x torch.Size([32, 16, 32, 32])

Cell_1 Out: torch.Size([32, 64, 32, 32])

Cell_2:========================
Cell_2 In:  torch.Size([32, 64, 32, 32]) torch.Size([32, 64, 32, 32])

Preproc0_in:  torch.Size([32, 64, 32, 32]), Preproc1_in: torch.Size([32, 64, 32, 32])
Preproc0_out: torch.Size([32, 32, 32, 32]), Preproc1_out: torch.Size([32, 32, 32, 32])

  Node_0 In:    1 x torch.Size([32, 32, 32, 32])
  Node_0 Out:   1 x torch.Size([32, 32, 32, 32])

  Node_1 In:    1 x torch.Size([32, 32, 32, 32])
  Node_1 Out:   1 x torch.Size([32, 32, 32, 32])

  Node_2 In:
     torch.Size([32, 32, 32, 32])
     torch.Size([32, 32, 32, 32])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_2 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_3 In:
     torch.Size([32, 32, 32, 32])
     torch.Size([32, 32, 32, 32])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_3 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_4 In:
     torch.Size([32, 32, 32, 32])
     torch.Size([32, 32, 32, 32])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_4 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_5 In:
     torch.Size([32, 32, 32, 32])
     torch.Size([32, 32, 32, 32])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_5 Out:   1 x torch.Size([32, 32, 16, 16])

Cell_2 Out: torch.Size([32, 128, 16, 16])

Cell_3:========================
Cell_3 In:  torch.Size([32, 64, 32, 32]) torch.Size([32, 128, 16, 16])

Preproc0_in:  torch.Size([32, 64, 32, 32]), Preproc1_in: torch.Size([32, 128, 16, 16])
Preproc0_out: torch.Size([32, 32, 16, 16]), Preproc1_out: torch.Size([32, 32, 16, 16])

  Node_0 In:    1 x torch.Size([32, 32, 16, 16])
  Node_0 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_1 In:    1 x torch.Size([32, 32, 16, 16])
  Node_1 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_2 In:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_2 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_3 In:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_3 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_4 In:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_4 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_5 In:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_5 Out:   1 x torch.Size([32, 32, 16, 16])

Cell_3 Out: torch.Size([32, 128, 16, 16])

Cell_4:========================
Cell_4 In:  torch.Size([32, 128, 16, 16]) torch.Size([32, 128, 16, 16])

Preproc0_in:  torch.Size([32, 128, 16, 16]), Preproc1_in: torch.Size([32, 128, 16, 16])
Preproc0_out: torch.Size([32, 32, 16, 16]), Preproc1_out: torch.Size([32, 32, 16, 16])

  Node_0 In:    1 x torch.Size([32, 32, 16, 16])
  Node_0 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_1 In:    1 x torch.Size([32, 32, 16, 16])
  Node_1 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_2 In:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_2 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_3 In:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_3 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_4 In:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_4 Out:   1 x torch.Size([32, 32, 16, 16])

  Node_5 In:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node pre_Out:
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
     torch.Size([32, 32, 16, 16])
  Node_5 Out:   1 x torch.Size([32, 32, 16, 16])

Cell_4 Out: torch.Size([32, 128, 16, 16])

Cell_5:========================
Cell_5 In:  torch.Size([32, 128, 16, 16]) torch.Size([32, 128, 16, 16])

Preproc0_in:  torch.Size([32, 128, 16, 16]), Preproc1_in: torch.Size([32, 128, 16, 16])
Preproc0_out: torch.Size([32, 64, 16, 16]), Preproc1_out: torch.Size([32, 64, 16, 16])

  Node_0 In:    1 x torch.Size([32, 64, 16, 16])
  Node_0 Out:   1 x torch.Size([32, 64, 16, 16])

  Node_1 In:    1 x torch.Size([32, 64, 16, 16])
  Node_1 Out:   1 x torch.Size([32, 64, 16, 16])

  Node_2 In:
     torch.Size([32, 64, 16, 16])
     torch.Size([32, 64, 16, 16])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_2 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_3 In:
     torch.Size([32, 64, 16, 16])
     torch.Size([32, 64, 16, 16])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_3 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_4 In:
     torch.Size([32, 64, 16, 16])
     torch.Size([32, 64, 16, 16])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_4 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_5 In:
     torch.Size([32, 64, 16, 16])
     torch.Size([32, 64, 16, 16])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_5 Out:   1 x torch.Size([32, 64, 8, 8])

Cell_5 Out: torch.Size([32, 256, 8, 8])

Cell_6:========================
Cell_6 In:  torch.Size([32, 128, 16, 16]) torch.Size([32, 256, 8, 8])

Preproc0_in:  torch.Size([32, 128, 16, 16]), Preproc1_in: torch.Size([32, 256, 8, 8])
Preproc0_out: torch.Size([32, 64, 8, 8]), Preproc1_out: torch.Size([32, 64, 8, 8])

  Node_0 In:    1 x torch.Size([32, 64, 8, 8])
  Node_0 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_1 In:    1 x torch.Size([32, 64, 8, 8])
  Node_1 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_2 In:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_2 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_3 In:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_3 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_4 In:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_4 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_5 In:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_5 Out:   1 x torch.Size([32, 64, 8, 8])

Cell_6 Out: torch.Size([32, 256, 8, 8])

Cell_7:========================
Cell_7 In:  torch.Size([32, 256, 8, 8]) torch.Size([32, 256, 8, 8])

Preproc0_in:  torch.Size([32, 256, 8, 8]), Preproc1_in: torch.Size([32, 256, 8, 8])
Preproc0_out: torch.Size([32, 64, 8, 8]), Preproc1_out: torch.Size([32, 64, 8, 8])

  Node_0 In:    1 x torch.Size([32, 64, 8, 8])
  Node_0 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_1 In:    1 x torch.Size([32, 64, 8, 8])
  Node_1 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_2 In:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_2 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_3 In:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_3 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_4 In:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_4 Out:   1 x torch.Size([32, 64, 8, 8])

  Node_5 In:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node pre_Out:
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
     torch.Size([32, 64, 8, 8])
  Node_5 Out:   1 x torch.Size([32, 64, 8, 8])

Cell_7 Out: torch.Size([32, 256, 8, 8])

CNN Out:  torch.Size([32, 10])

Discretized architecture

Each node keeps the 2 operations with the largest architecture parameters, which yields the discrete architecture:

// epoch_49.json
{
  "normal_n2_p0": "sepconv3x3",
  "normal_n2_p1": "sepconv3x3",
  "normal_n2_switch": [
    "normal_n2_p0",
    "normal_n2_p1"
  ],
  "normal_n3_p0": "skipconnect",
  "normal_n3_p1": "sepconv3x3",
  "normal_n3_p2": [],
  "normal_n3_switch": [
    "normal_n3_p0",
    "normal_n3_p1"
  ],
  "normal_n4_p0": "sepconv3x3",
  "normal_n4_p1": "skipconnect",
  "normal_n4_p2": [],
  "normal_n4_p3": [],
  "normal_n4_switch": [
    "normal_n4_p0",
    "normal_n4_p1"
  ],
  "normal_n5_p0": "skipconnect",
  "normal_n5_p1": "skipconnect",
  "normal_n5_p2": [],
  "normal_n5_p3": [],
  "normal_n5_p4": [],
  "normal_n5_switch": [
    "normal_n5_p0",
    "normal_n5_p1"
  ],
  "reduce_n2_p0": "maxpool",
  "reduce_n2_p1": "avgpool",
  "reduce_n2_switch": [
    "reduce_n2_p0",
    "reduce_n2_p1"
  ],
  "reduce_n3_p0": "maxpool",
  "reduce_n3_p1": [],
  "reduce_n3_p2": "skipconnect",
  "reduce_n3_switch": [
    "reduce_n3_p0",
    "reduce_n3_p2"
  ],
  "reduce_n4_p0": [],
  "reduce_n4_p1": [],
  "reduce_n4_p2": "skipconnect",
  "reduce_n4_p3": "skipconnect",
  "reduce_n4_switch": [
    "reduce_n4_p2",
    "reduce_n4_p3"
  ],
  "reduce_n5_p0": [],
  "reduce_n5_p1": "avgpool",
  "reduce_n5_p2": "skipconnect",
  "reduce_n5_p3": [],
  "reduce_n5_p4": [],
  "reduce_n5_switch": [
    "reduce_n5_p1",
    "reduce_n5_p2"
  ]
}

Conclusion

  • Proposed DARTS, a simple and efficient architecture search algorithm for CNNs and RNNs that reaches SOTA results

  • Several orders of magnitude more efficient than previous methods


Future improvements:

  • The gap between the continuous architecture encoding and the derived discrete architecture
  • Methods based on parameter sharing?

Summary



