聯邦平均算法(Federated Averaging Algorithm,FedAvg)


設一共有\(K\)個客戶機,

中心服務器初始化模型參數,執行若干輪(round),每輪選取至少1個至多\(K\)個客戶機參與訓練,接下來每個被選中的客戶機同時在自己的本地根據服務器下發的本輪(\(t\)輪)模型\(w_t\)用自己的數據訓練自己的模型\(w^k_{t+1}\),上傳回服務器。服務器將收集來的各客戶機的模型根據各方樣本數量用加權平均的方式進行聚合,得到下一輪的模型\(w_{t+1}\)

\[\begin{aligned} & \qquad w_{t+1} \leftarrow \sum^K_{k=1} \frac{n_k}{n} w^k_{t+1} \qquad\qquad //n_k為客戶機k上的樣本數量,n為所有被選中客戶機的總樣本數量\\ \end{aligned} \]

【偽代碼】

\[\begin{aligned} & 算法1:Federated\ Averaging算法(FedAvg)。 \\ & K個客戶端編號為k;B,E,\eta分別代表本地的minibatch\ size,epochs,學習率learning\ rate \\ & \\ & 服務器執行:\\ & \quad 初始化w_0 \\ & \quad for \ 每輪t=1,2,...,do \\ & \qquad m \leftarrow max(C \cdot K,1) \qquad\qquad //C為比例系數 \\ & \qquad S_t \leftarrow (隨機選取m個客戶端) \\ & \qquad for \ 每個客戶端k \in S_t 同時\ do \\ & \qquad \qquad w^k_{t+1} \leftarrow 客戶端更新(k,w_t) \\ & \qquad w_{t+1} \leftarrow \sum^K_{k=1} \frac{n_k}{n} w^k_{t+1} \qquad\qquad //n_k為客戶機k上的樣本數量,n為所有被選中客戶機的總樣本數量\\ & \\ & 客戶端更新(k,w): \qquad \triangleright 在客戶端k上運行 \\ & \quad \beta \leftarrow (將P_k分成若干大小為B的batch) \qquad\qquad //P_k為客戶機k上數據點的索引集,P_k大小為n_k \\ & \quad for\ 每個本地的epoch\ i(1\sim E) \ do \\ & \qquad for\ batch\ b \in \beta \ do \\ & \qquad \qquad w \leftarrow w-\eta \triangledown l(w;b) \qquad\qquad //\triangledown 為計算梯度,l(w;b)為損失函數\\ & \quad 返回w給服務器 \end{aligned} \]

為了增加客戶機計算量,可以在中心服務器做聚合(加權平均)操作前在每個客戶機上多迭代更新幾次。計算量由三個參數決定:

  • \(C\),每一輪(round)參與計算的客戶機比例。
  • \(E(epochs)\),每一輪每個客戶機投入其全部本地數據訓練一遍的次數。
  • \(B(batch size)\),用於客戶機更新的batch大小。\(B=\infty\)表示batch為全部樣本,此時就是full-batch梯度下降了。

\(E=1\ B=\infty\)時,對應的就是FedSGD,即每一輪客戶機一次性將所有本地數據投入訓練,更新模型參數。

對於一個有着\(n_k\)個本地樣本的客戶機\(k\)來說,每輪的本地更新次數為\(u_k=E\cdot \frac{n_k}{B}\)

參考文獻:

  1. H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. A. Y. Arcas, “Communication-efficient learning of deep networks from decentralized data,” in Proc. AISTATS, 2016, pp. 1273–1282.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM