1. 背景介紹
-
系統部署
移動手機和可穿戴設備是現代十分常見的數據產生設備。這些設備每天都會產生巨量的各種形式的數據。考慮到算力需求,數據傳輸以及個人隱私的限制,系統部署越來越傾向於在本地存儲數據,模型計算由邊緣設備完成。 -
數據孤島
數據往往以孤島形式出現。在現實中想要將分散在各地、各個機構的數據進行整合幾乎是不可能的,或者說所需的成本是巨大的。
2. 聯邦學習概念
- 本質:一種分布式機器學習技術,或機器學習框架,讓人工智能系統能夠更加高效、准確的共同使用各自的數據,實現共同建模,提高 AI 模型的效果。
- 公式化定義
經典的聯邦學習問題需要從上百萬的遠程設備中存儲的海量數據里面學習到一個全局統計模型。這個任務可以用以下目標函數來表述:
$ min \quad F(w), \quad where \quad F(w) := \Sigma^{m}_{k=1}{p}_k{F}_k(w) $
- 其中 m 代表設備總量,\(F_k\) 為第 k 個設備的本地目標函數,\(p_k\) 被定義為對應設備的影響權重。
- \(p_k\) 具有性質:\(p_k\) ≥ 0,且 \(\Sigma^{m}_{k=1}{p}_k = 1\)。
- \(F_k\)通常被定義為基於本地數據的經驗風險。
-
聯邦學習與現有研究的區別
聯邦學習中工作節點代表的是模型訓練的數據擁有方,其對本地的數據具有完全自治的權限,可以自主決定何時加入聯邦學習進行建模,相對地在參數服務器中,中心節點始終占據着主導地位。
聯邦學習本質上仍是一種分布式機器學習,所以不完全認同該圖中與分布式機器學習的區別,況且完全可以設計不需要傳輸數據,只傳輸梯度的分布式機器學習。
- 聯邦學習與傳統分布式學習的區分
- 用戶對於自己的設備和有着控制權。傳統分布式往往由 server 控制。
- Worker節點是不穩定的,比如手機可能突然就沒電了,或者進入了電梯突然沒信號了等情況。傳統分布式學習 worker 往往是在機房中,穩定。
- 通信代價往往比計算代價要高。聯邦學習 worker 往往通過無線通信,通信時延大。
- 分布在Worker節點上的數據並不是獨立同分布的(not IID),因此很多已有的減少通信次數的算法就不適用。
- 節點負載不平衡,有的設備數據多有的設備數據少。比如有的用戶幾天拍一張照片有的用戶一天拍好多照片,這給建模帶來了困難。如果給圖片的權重一樣,那么模型可能往往取決於拍圖片多的用戶,拍照少的用戶就被忽略了。如果用戶的權重相同,這樣學出來的模型對拍照多的用戶又不太好了。負載不平衡也給計算帶來了挑戰,數據少的用戶可能一下子算了很多epoc了,數據多的用戶還早着。這一點上,聯邦學習不像傳統的分布式學習可以做負載均衡,即將一個節點的數據轉移到另一個節點。
3. 聯邦學習的分類
把每個參與共同建模的企業稱為參與方,根據多參與方之間數據分布的不同,把聯邦學習分為三類:橫向聯邦學習、縱向聯邦學習與聯邦遷移學習(工業界目前用的少)。
3.1. 橫向聯邦學習
谷歌最初采用橫向聯邦的方式解決安卓手機終端用戶在本地更新模型的問題的。
- 本質是樣本的聯合。適用於參與者間業態相同但觸達客戶不同,即特征重疊多,用戶重疊少時的場景,比如不同地區的銀行間,業務相似(特征相似),用戶不同(樣本不同)。
- 學習過程
step 1:參與方在本地計算模型梯度,然后將梯度結果加密上傳到服務器;
step 2:服務器 A 聚合各用戶的梯度,更新模型參數;
step 3:服務器 A 返回更新后的模型給各參與方;
step 4:各參與方基於加密梯度更新各自模型。 - 步驟解讀
- 傳統機器學習建模時,通常把模型訓練集數據集合到一個數據中心,然后再訓練模型。
- 橫向聯邦學習中,可以看作是基於樣本的分布式模型訓練,分發全部數據到不同的機器,每台機器從服務器下載模型,然后利用本地數據訓練模型,之后返回給服務器需要更新的參數;服務器聚合各機器上的返回的參數,更新模型,再把最新的模型反饋到每台機器。
- 整個過程中每台機器上都是相同且完整的模型,且機器之間不交流不依賴,在預測時每台機器也可以獨立預測。
3.2. 縱向聯邦學習
- 本質是特征的聯合,適用於用戶重疊多,特征重疊少的場景,比如同一地區的商超和銀行,他們之間用戶相同,但業務不同(特征)。
- 學習過程
縱向聯邦學習的本質是交叉用戶在不同業態下的特征聯合,比如商超A和銀行B,在傳統的機器學習建模過程中,需要將兩部分數據集中到一個數據中心,然后再將每個用戶的特征 join成一條數據用來訓練模型,所以就需要雙方有用戶交集(基於join結果建模),並有一方存在label。
-
學習步驟
- 第三方 C 加密樣本對齊。在系統級做這件事,因此在企業感知層面不會暴露非交叉用戶。
- 對齊樣本進行模型加密訓練:
step1:由第三方C向A和B發送公鑰,用來加密需要傳輸的數據;
step2:A和B分別計算和自己相關的特征中間結果,並加密交互,用來求得各自梯度和損失;
step3:A和B分別計算各自加密后的梯度並添加掩碼發送給C,同時B計算加密后的損失發送給C;
step4:C解密梯度和損失后回傳給A和B,A、B去除掩碼並更新模型。
-
步驟解讀
縱向聯邦學習的具體訓練步驟如下:
在整個過程中參與方都不知道另一方的數據和特征,且訓練結束后參與方只得到自己側的模型參數,即半模型。
- 預測過程:
由於各參與方只能得到與自己相關的模型參數,預測時需要雙方協作完成,如下圖所示:
共同建模的結果:
- 雙方均獲得數據保護
- 共同提升模型效果
- 模型無損失
3.3. 聯邦遷移學習
-
適用場景
當參與者特征和樣本重疊都很少時可以考慮使用聯邦遷移學習,如不同地區的商超和銀行間的聯合。 -
定義
- 遷移學習,是指利用數據、任務、或模型之間的相似性,將在源領域學習過的模型,應用於 目標領域的一種學習過程。
- 在兩個數據集的用戶與用戶特征重疊都較少的情況下,我們不對數據進行切分,而可以利用遷移學習來客服數據或標簽不足的情況。
遷移學習的核心是,找到源領域和目標領域之間的相似性,舉一個楊強教授經常舉的例子來說明:我們都知道在中國大陸開車時,駕駛員坐在左邊,靠馬路右側行駛。這是基本的規則。然而,如果在英國、香港等地區開車,駕駛員是坐在右邊,需要靠馬路左側行駛。那么,如果我們從中國大陸到了香港,應該如何快速地適應 他們的開車方式呢?訣竅就是找到這里的不變量:不論在哪個地區,駕駛員都是緊靠馬路中間。這就是我們這個開車問題中的不變量。 找到相似性 (不變量),是進行遷移學習的核心。
- 訓練過程和推理過程
聯邦遷移學習的步驟與縱向聯邦學習相似,只是中間傳遞結果不同(實際上每個模型的中間傳遞結果都不同)。
4. 聯邦學習研究熱點
4.1 Communication Efficiency
並行梯度下降中(parallel gradient descent),第 \({i}\) 個worker執行了任務:
- 從server接收模型參數 \({w}\)
- 根據 \({w}\) 和本地數據計算梯度 \({g_i}\)
- 將 \({g_i}\) 發送給server
然后 server 接收了所有用戶的 \({g_i}\) 之后,執行任務:
- 接收 \({g_1}, {g_2},...,{g_m}\)
- 計算:\({g = \Sigma^m_{i=1}{g_i}}\)
- 做一次梯度下降,更新模型參數:\(w_{i+1} = w_{i} - \alpha \cdot g\)
- 然后將新的參數發送給用戶,等待用戶數據重復執行下一輪迭代
federated averaging algorithm 中,用更少的通信次數達到了收斂。
federated averaging algorithm 中, worker 執行任務:
- 從 server 接收參數 \(w\)
- 迭代一下過程:
a. 根據 \({w}\) 和本地數據計算梯度 \({g}\)
b. 本地化更新:\(w = w - \alpha \cdot g\)- 將 \(\widetilde{w}_{i}=w\) 傳給 server
然后 server 接收了全部 \(\widetilde{w}\) 后,執行更新:
更新 \(w \leftarrow \frac{1}{m}\left(\widetilde{w}_{1}+\cdots+\widetilde{w}_{m}\right)\),下一輪迭代時將此 \(w\) 傳給所有 server
- 其他結論:
- 相同次數的通信, Federated Averaging收斂的更快。兩次通信之間 Federated Averaging 讓worker 節點做大量計算,以犧牲計算量為代價換取更小的通信次數。
- 相同次數的epochs,梯度下降收斂的會更快。
- FedAvg 來作聯邦學習,數據不需要獨立同分布。
4.2 Privacy
聯邦學習中,用戶的數據始終沒有離開用戶,那么數據是否安全呢?
實際上算梯度的過程就是對數據的一個變換,將數據映射到梯度。
雖然數據沒有發出去,但是梯度是幾乎包含數據所有信息的,所以一定程度上,可以通過梯度反推出數據。如將梯度作為輸入特征,然后學習一個分類器,其根本原理就是梯度帶有用戶信息。當前主流抵御這種攻擊的辦法往梯度加噪聲,但這容易帶來模型不准確,准確率降低等問題。
4.3 Adversarial Robustness
第三個研究熱點讓聯邦學習可以抵御拜占庭錯誤和惡意攻擊。簡單說就是 worker 中出了叛徒,如何學到更好地模型。
- Attack 1 :將部分測試集數據進行修改,使這部分數據成為“毒葯”;
- Attack 2 :將本地數據的標簽換成錯的,用正確的圖片和錯誤的標簽來計算梯度方向;
- Defense 1 :server 用某個 worker 傳來的梯度來更新參數,再測試用了該參數的模型准確率;(效果一般)
- Defense 2:Server 比較各 worker 傳過來的梯度,檢驗是否存在差異很大的梯度;假設了數據獨立同分布,但實際上聯邦學習場景的數據往往不是獨立同分布的,效果一般。
- Defense 3: Server 對各 worker 傳來的梯度加權平均,如中位數。同樣假設了數據獨立同分布。
Shusen Wang 結論:聯邦學習目前沒有良好的抵御方法。
5. 聯邦學習的學習資料
-
聯邦學習生態網站
-
開源項目
- FederatedAI / FATE :https://github.com/FederatedAI/FATE
- Eggroll WeBankFinTech/eggroll:https://github.com/WeBankFinTech/Eggroll
-
視頻資料
-
Shusen Wang 《分布式機器學習》
-
並行計算與機器學習(1/3)(中文) Parallel Computing for Machine Learning (Part 1/3):並行計算基礎以及MapReduce
這節課的主要內容:
0:28 Motivation:並行計算有什么用?為什么機器學習的人需要懂並行計算。
2:42 最小二乘回歸(如果已經懂最小二乘,建議跳到 6:43 )。
6:43 用並行計算來解最小二乘回歸。
11:14 並行計算中的通信問題。
13:10 MapReduce,已經如何用MapReduce實現並行梯度下降,以及通信、同步的問題 -
並行計算與機器學習(2/3)(中文) Parallel Computing for Machine Learning (Part 2/3):參數服務器、去中心化
這節課的主要內容:
1:02 Parameter Server(參數服務器),以及用Parameter Server實現異步梯度下降。
8:27 Decentralized Network(去中心化網絡), 以及用Decentralized Network實現梯度下降。
15:39 Parallel Computing(並行計算)和Distributed Computing(分布式計算)的異同。 -
並行計算與機器學習(3/3)(中文) Parallel Computing for Machine Learning (Part 3/3):Ring All-Reduce
這節課的主要介紹TensorFlow中的並行計算庫、以及其中Ring All-Reduce的原理。
這節課的主要內容:
1:09 如何應用TensorFlow的並行計算庫訓練神經網絡
7:41 Ring All-Reduce的技術原理。 -
聯邦學習:技術角度的講解(中文)Introduction to Federated Learning
這節課的主要內容:
3:13 分布式機器學習
6:07 聯邦學習和傳統分布式學習的區別
12:46 聯邦學習中的通信問題
15:24 Federated Averaging算法
21:24 聯邦學習中的隱私泄露和隱私保護
27:52 聯邦學習中的安全問題(拜占庭錯誤、data poisoning、model poisoning)
33:00 總結
-
-
Reference
[1] 聯邦學習 Federated Learning: https://zhuanlan.zhihu.com/p/93761403
[2] 詳解聯邦學習Federated Learning:https://zhuanlan.zhihu.com/p/79284686
[3] 綜述:《聯邦學習:概念與應用》:https://zhuanlan.zhihu.com/p/127319831
[4] 分布式機器學習(下)-聯邦學習:https://zhuanlan.zhihu.com/p/114028503
Shusen Wang "並行計算與分布式學習 "、"聯邦學習:技術角度的講解(中文)Introduction to Federated Learning" 系列視頻。