【論文解讀】Federated Learning of Deep Networks using Model Averaging 模型平均下的深度網絡聯邦學習


一、闡述了聯邦學習的誕生背景:

在當前數據具有價值,並且需要被保護,數據分布為non-IID情況下,需要提出一個框架來進行行之有效的訓練,這也是聯邦學習誕生的原因;

 

二、論文的相關工作:

首先,論文闡述了聯邦學習所適用的領域:

1.數據集應該具有較大隱私,所以無法上傳;

2.對於有監督學習下的任務,可以很輕易地判斷其標簽;

 

隨后,論文舉了兩個基本例子:

1.典型的圖像分類:根據學習用戶以往的瀏覽照片類型來判斷能夠查詢哪些照片;

2.典型的語言模型:典型的詞語預測系統,通過以往的記錄進行分析;

這兩個例子和論文中所提到的適用領域不謀而合,因為:

1.類型可以通過用戶標記來定義;

2.對於不同用戶的習慣,數據分布很可能不同。例如官方語言和俚語,Flickr照片和手機照片;

 

后續論文通過這兩個問題使用兩種不同的網絡模型來進行測試。

圖像分類——前饋深層網絡;

語言模型——LSTM;

通過這兩個實驗來進行聯邦學習隱私又是和降低大型數據通信成本方面的測試;

 

最后,論文也論證了聯邦優化較於分布式優化的差別和實際會遇到的問題。

主要來說,聯邦優化和分布式優化的區別在於數據集上:

1.聯邦優化數據集為non-IID,任何一個節點的數據集分布無法代表整體;

2.數據分布不平均,有的節點數據集多,有的節點數據集少,也就是不平衡屬性;

而我們實際中會遇到的問題(本篇論文討論理想狀態下,並不做過多涉及):

1.客戶端數據回添加或者刪除;

2.客戶端有可能不發送數據或者發的數據有問題;

3.客戶端可用性可能會因為數據的分布的不同形式而受到影響(美式英語語音和英式英語語音可能會混雜在一起);

 

三、詳細公式推導過程以及算法流程:

一般性的前提條件:

假設客戶結點固定,為K個,並且每個客戶節點有固定的本地數據集。每一輪開始,選擇客戶的隨即分數C,將服務器當前的全局參數發送給每個客戶端,每個客戶端基於全局狀態+本地數據集進行本地計算,將更新發送給服務器,服務器更新並且用於全局狀態,重復該過程。

 

基本的數學表示:

對於整個學習下的目標函數,應該為

其中,L代表loss函數。

 

這里稍微提一下,自己初次看有點懵逼,現在發現是把神經網絡全部忘光了。


 

【補充】:

這是一個典型的神經網絡的的非凸損失函數。

其中n代表樣本點哥鼠,fi(w)代表根據每個樣本點i算出來的損失函數,最后進行整個求均值,個人以依稀記得每輪計算整個數據集的f(w),來進行更新;


 

而對於 聯邦學習下的目標函數,也是神經網絡下的改版,其實就是區分了個結點而已。

 

 其中,K為節點個數總共為K個。nk代表結點k中的訓練集樣本個數;

所以總而言之就是把每個epoch的目標函數變成每個節點的目標函數的加和。對於單個結點內的局部更新,和上面的神經網絡的目標函數不變。

 

具體算法的步驟:

總體還是用的神經網絡那一套,使用每個epoch所得到的w矩陣參數來更新下一個epoch。

只不過設計一個全局w的問題;

 

 如上圖所示;

其實思路很簡單:

1.通過在每一個批次中選取部分節點,進行一個epoch的訓練,之后每個結點上傳服務器。

2.服務器將所有的w,進行加和求均值得到新的w,在下發給每個結點。

3.每個結點將下發的結點代替上一個epoch算出的w,進行新的epoch的訓練。

重復上述三步直到服務器確定w收斂為止。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM