Optimizing Federated Learning on Non-IID Data with Reinforcement Learning 筆記


  • 閱讀論文 Optimizing Federated Learning on Non-IID Data with Reinforcement Learning 的筆記
  • 如有侵權,請聯系作者,將會撤銷發布。

主要講什么

  • 提出FAVOR,一個經驗驅動控制的框架。
  • 智能的選擇客戶端設備來參與聯邦學習中每一輪訓練,以抵消數據非獨立同分布帶啊來的偏差,並提升收斂的速度。
  • 使用了deep Q-learning 來學習如何選擇每輪參與訓練的客戶端以最大化一個 鼓勵提升正確率並處罰使用更多通信次數的 獎勵。

Intro

  • 一般聯邦學習都是直接隨機選取一部分設備參與每輪的訓練,以避免由於不穩定的網絡狀況和straggler設備造成的長尾(long-tailed)等待時間
  • FedAvg可能會嚴重的降低模型的准確性和收斂所需的通信次數
    • 而且由於數據非獨立同分布,聚合這些不同的模型可能會減慢收斂,並且會降低模型准確性
    • 一個設備中的訓練數據的分布和訓練得到的模型參數之間有內含的聯系

這篇文章提出的目標

FAVOR的目標

  • 通過學習積極地在每輪選擇最好的,可以抵消非獨立同分布會帶來的偏差的設備集,以加速並穩定聯邦學習過程。

選擇設備

  • 用本地模型參數和共享的全局模型作為狀態,從而公平地?選擇可能對全局模型有所提升的設備
  • 使用基於DQN的強化學習來提高效率和魯棒性。(在FL的設備選擇環節中使用基於DQN的強化學習)

壓縮模型參數

  • 提出了一個可以壓縮模型參數以對狀態空間降維
  • apply principle component analysis(PCA) to model weights and use the compressed model weights to represent states instead.
  • 只根據在第一輪訓練(step 2中得到的)的本地模型的參數來計算PCA
  • # TODO: 看不懂源碼,看不懂過程QAQ

非獨立同分布的挑戰

  • 論文中用實驗來展現:
    • 如果隨機選取設備,那么非獨立同分布的數據可能會減慢聯邦學習的收斂速度。
    • 用cluster 算法可以有助於平衡數據分布並加快收斂。

實驗過程

  1. 100個設備下載最初的Global weights(隨機生成的)然后根據本地數據執行一個epoch的SGD,獲得\(w_1^{(k)}\)
  2. \(w_1^{(k)}\)執行K-Center算法,對100個設備進行聚類,分成了10個組。
  3. 每個組里面隨機選擇一個設備進行聯邦學習。
  • 結果:

  • 這個實驗說明了:通過仔細選擇每輪參與訓練的設備可以提高聯邦學習的性能。

用DRL來選擇客戶端

Agent 基於 Deep Q-Network

  • 用DQN來選擇k個最合適的設備來參與訓練
  • 通過一個網絡來學習得到\(Q^*(s_t,a)\),選擇\(Q^*\)最大的k個設備來訓練。
  • 因為設備中數據非獨立同分布的原因,直接隨機選擇設備來訓練效果會不好,所以用這個DQN可以根據每個設備中的模型參數來訓練,得到一個選擇設備的策略。
  • \(s_t=(w_t,w_t^{(1)},...,w_t^{(N)})\)
  • \(a\): action space為{1,2,...,N}, a=1指選擇設備i去參與FL訓練
  • DQN agent 被訓練為要最大化cumulative discounted reward (即R) 的期望。:
    • reward: \(r_t=\Xi ^{(w_t-\Omega)}-1\)
      • \(w_t\): 在第 t 輪結束后,對held-out validation set(保留驗證集)上的測試得出的准確度
      • \(\Omega\): 目標准確度
      • \(\Xi ^{(w_t-\Omega)}\): 激勵agent去選擇能取得更高准確度\(w_t\)的設備
        • 由於通常隨着在機器學習進行,模型准確度的增長速度會變慢,也就是隨着t增加,\(|w_t-w_{t+1}|\)會減小。
        • 所以用這樣的指數項來放大FL過程靠后階段中微小的准確度的增長。
        • \(\Xi\): 一個正常數,論文中的實驗設置為了64
      • -1:激勵 agent 用更少的訓練輪數 (?)
    • \(R=\sum_{t=1}^{T}\gamma ^{t-1}r_t\)
    • \(w_t=\Omega, r_t == 0\) 時,聯邦學習結束

FAVOR過程

  1. N個可行的設備向FL server報到
    1. 每個設備都從server上下載最初的隨機獲得的模型參數\(w_{init}\)
    2. 用 local SGD 訓練一個epoch,然后將訓練得到的模型參數\(\{w_1^{(k)},k \in [N]\}\)傳給FL server
    1. 接收到上傳的local weights后,對應在server上存的local weights更新
    2. DQN agent 計算所有設備的\(Q(s_t,a;\theta)\)
    1. DQN agent 根據\(Q(s_t,a;\theta)\)的大小,選擇k個最大Q值對應的k個設備。
    2. 被選中的k個設備下載最新的global model weights \(w_t\), 並執行一個epoch的local SGD以獲得\(\{w_{t+1}^{(k)}|k \in [K]\}\)
  2. \(\{w_{t+1}^{(k)}|k \in [K]\}\)被傳到server,以使用FEDAVG計算\(w_{t+1}\)。重復3-5步直到結束(如達到目標准確率,或者 訓練了一定數量的rounds)。
  • 論文作者GitHub上還沒有給出這部分的代碼。

用PCA降維

  • 對模型參數使用PCA,然后用壓縮后的模型參數來表示states。
  • 看不懂這部分代碼

用Double DQN 訓練Agent

  • 使用DDQN來學習函數\(Q^*(s_t,a)\)
  • 原來的Q-Learning可能會不穩定
  • 而DDQN加入了另一個value function \(Q(s,a;\theta_t')\),這樣可以使action-value函數的估計更加穩定


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM