Notes on the Convergence of FedAvg in Federated Learning


Notation

\(F(w)\): the global objective function
\(w\): the parameters to be optimized
\(F_k(w)\): the objective function of the \(k\)-th client
\(p_k\): the weight of client \(k\) in the global objective
\(E\): the number of local updates per round
\(T\): the total number of iterations, so the number of communication rounds is \(T/E\)
\(\eta_t\): the learning rate at time \(t\)
\(L\): each \(F_k\) is \(L\)-smooth, i.e. \(\nabla^2 F_k(w)\preceq L I\)
\(\mu\): each \(F_k\) is \(\mu\)-strongly convex, i.e. \(\nabla^2 F_k(w)\succeq \mu I\)
\(\sigma_k\): \(E\Vert\nabla F_k(w_t^k, \xi^k_t)-\nabla F_k(w_t^k)\Vert^2\leq\sigma^2_k\)
\(G\): \(E\Vert\nabla F_k(w_t^k, \xi^k_t)\Vert^2\leq G^2\)
\(F^*\): the optimal value of the global objective
\(F_k^*\): the optimal value obtained by optimizing client \(k\) alone
\(\Gamma\): \(F^*-\sum_k\,p_kF_k^*\), a measure of heterogeneity
\(\kappa\): \(\frac{L}{\mu}\), which can be viewed as a condition number
\(w^*\): the optimal parameters
\(\xi_t^k\): the sample drawn by client \(k\) at time \(t\) for its stochastic gradient step

Assumptions

  1. Each \(F_k\) is \(L\)-smooth, for all \(k\).
  2. Each \(F_k\) is \(\mu\)-strongly convex, for all \(k\).
  3. The variance of each stochastic gradient computed by client \(k\) is bounded by \(\sigma_k^2\).
  4. The squared norm of every stochastic gradient is bounded by \(G^2\).
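To make the notation concrete, here is a minimal FedAvg sketch under the setup above. The quadratic client objectives \(F_k(w)=\frac12\Vert A_k w - b_k\Vert^2\) stand in for real losses, and all names and constants are illustrative, not from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, E, T, eta = 4, 3, 5, 100, 0.02   # N clients, dim d, E local steps, T total steps
p = np.ones(N) / N                     # p_k: client weights, sum to 1
A = [rng.normal(size=(8, d)) for _ in range(N)]
b = [rng.normal(size=8) for _ in range(N)]

def grad(k, w):
    # gradient of F_k(w) = 0.5 * ||A_k w - b_k||^2
    # (a stochastic version would sample rows, playing the role of xi_t^k)
    return A[k].T @ (A[k] @ w - b[k])

def global_loss(w):
    # F(w) = sum_k p_k F_k(w)
    return sum(p[k] * 0.5 * np.sum((A[k] @ w - b[k]) ** 2) for k in range(N))

w_global = np.zeros(d)
init_loss = global_loss(w_global)
for _ in range(T // E):                # T/E communication rounds
    local = []
    for k in range(N):
        w = w_global.copy()
        for _ in range(E):             # E local gradient steps per round
            w = w - eta * grad(k, w)
        local.append(w)
    # aggregation: w_{t+1} = sum_k p_k v_{t+1}^k  (for t+1 in I_E)
    w_global = sum(pk * wk for pk, wk in zip(p, local))
final_loss = global_loss(w_global)
```

Running the sketch, the global loss decreases across communication rounds, which is exactly the behavior the proof below quantifies.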

Convergence proof under full participation

Introduce two sequences, \(v^k_t\) and \(w_t^k\):

\[\begin{align*} v_{t+1}^k &= v_t^k - \eta_t\nabla F_k(w_t^k, \xi_t^k)\\ w_{t+1}^k &= \left\{ \begin{matrix} v_{t+1}^k \quad &\text{for }t+1 \notin I_E\\ \sum_{k=1}^{N}\,p_k v_{t+1}^k \quad &\text{for }t+1 \in I_E\end{matrix} \right. \end{align*} \]

Define

\[\begin{align*} \bar v_t &= \sum\nolimits_k \, p_k v_t^k\\ \bar w_t &= \sum\nolimits_k \, p_k w_t^k\\ \bar g_t &= \sum\nolimits_k\, p_k\nabla F_k(w_t^k)\\ g_t &=\sum\nolimits_k p_k\nabla F_k(w_t^k, \xi_t^k) \end{align*} \]

The parameters \(w\) can only be aggregated at communication rounds, i.e. when \(t\in I_E\); the sequence \(v\) represents the parameters at rounds where no data exchange takes place. Under full participation, \(\bar v_t = \bar w_t\) for all \(t\), and moreover \(\bar v_{t+1} = \bar w_t - \eta_t g_t\).

Personal understanding:

To prove convergence we need to show that the parameters converge. Since \(\bar w_t\) is produced by gradient descent, we need to show

\[E\Vert \bar w_{t+1} - w^*\Vert^2 \leq l\left(E\Vert \bar w_{t} - w^*\Vert^2\right) \]

That is, the distance between the current iterate and the optimum \(w^*\) is smaller than the distance between the previous iterate and \(w^*\), and the function \(l\) can be applied recursively. In other words, the upper bound on the distance to \(w^*\) shrinks as the iterations proceed.

The paper works with \(\bar v - w^*\) rather than \(\bar w - w^*\), because \(\bar v\) aggregates over all clients; under partial participation \(\bar w\) is biased.

\[\begin{align*} E\Vert \bar v_{t+1} - w^*\Vert^2 &= E\Vert \bar w_t - \eta_t g_t -w^*-\eta_t \bar g_t + \eta_t \bar g_t\Vert^2\\ &= E\left(\Vert \bar w_t -\eta _t \bar g_t -w^*\Vert ^2 + 2\eta_t \langle\bar w_t-\eta_t\bar g_t -w^*, -g_t+\bar g_t\rangle + \eta_t^2 \Vert\bar g_t-g_t\Vert^2\right) \end{align*}\label{eq:1} \tag{1} \]

The reason for inserting \(-\eta_t \bar g_t + \eta_t \bar g_t\) above is to exploit \(E(\eta_t g_t - \eta_t \bar g_t)=0\), i.e. the stochastic gradient is unbiased, so the cross term in \(\ref{eq:1}\) vanishes in expectation.
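A quick numerical sanity check of this unbiasedness, using a hypothetical quadratic loss where \(\xi\) samples one data row (all values illustrative): averaging many stochastic gradients should recover the full gradient.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 6, 3
A = rng.normal(size=(n, d))
b = rng.normal(size=n)
w = rng.normal(size=d)

def full_grad(w):
    # gradient of F(w) = (1/n) * sum_i 0.5 * (a_i^T w - b_i)^2
    return A.T @ (A @ w - b) / n

def stoch_grad(w):
    i = rng.integers(n)                # xi: a uniformly sampled row index
    return A[i] * (A[i] @ w - b[i])

# the empirical mean of stochastic gradients converges to the full gradient
est = np.mean([stoch_grad(w) for _ in range(200_000)], axis=0)
err = float(np.linalg.norm(est - full_grad(w)))
```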

Expanding \(E\left(\Vert \bar w_t -\eta _t \bar g_t -w^*\Vert ^2\right)\) further,

\[\begin{align*} \Vert \bar w_t - \eta_t \bar g_t - w^*\Vert^2&= \Vert \bar w_t - w^*\Vert^2 - 2\eta_t\langle\bar w_t-w^*, \bar g_t\rangle +\eta_t^2\Vert \bar g_t\Vert^2 \label{eq:2}\tag{2} \end{align*} \]

By \(L\)-smoothness [1],

\[\begin{align*} \Vert \nabla F_k(w_t^k)\Vert^2 \leq 2L(F_k(w_t^k) - F_k^*) \label{eq:3} \tag{3} \end{align*} \]

Since the squared norm is convex, combining with \(\ref{eq:3}\) gives

\[\begin{align*} \eta^2_t \Vert \bar g_t \Vert^2 &\leq \eta_t^2 \sum \, p_k\Vert \nabla F_k(w_t^k)\Vert^2\\ &\leq 2L\eta_t^2\sum\, p_k(F_k(w_t^k) - F_k^*) \end{align*} \]

Expanding \(-2\eta_t \langle\bar w_t -w^*, \bar g_t\rangle\):

\[\begin{align*} -2\eta_t \langle\bar w_t -w^*, \bar g_t\rangle&=-2\eta_t \sum\, p_k \langle\bar w_t-w^*, \nabla F_k(w^k_t)\rangle \\ &= -2\eta_t \sum\, p_k \langle\bar w_t - w_t^k, \nabla F_k(w_t^k)\rangle-2\eta_t\sum\, p_k \langle w_t^k - w^*, \nabla F_k(w_t^k)\rangle \end{align*} \]

By the Cauchy–Schwarz and AM–GM inequalities,

\[\begin{align*} -2\langle\bar w_t-w_t^k, \nabla F_k(w_t^k) \rangle \leq \frac{1}{\eta_t} \Vert \bar w_t - w_t^k\Vert^2+\eta_t \Vert \nabla F_k(w_t^k)\Vert^2 \end{align*} \]

By \(\mu\)-strong convexity,

\[\begin{align*} -\langle w_t^k - w^*, \nabla F_k(w_t^k)\rangle \leq -(F_k(w_t^k)-F_k(w^*)) - \frac{\mu}{2}\Vert w_t^k-w^*\Vert^2 \end{align*} \]

Hence \(\ref{eq:2}\) becomes

\[\begin{align*} \Vert \bar w_t - \eta_t \bar g_t - w^*\Vert^2&= \Vert \bar w_t - w^*\Vert^2 - 2\eta_t\langle\bar w_t-w^*, \bar g_t\rangle +\eta_t^2\Vert \bar g_t\Vert^2\\ & \leq \Vert \bar w_t - w^*\Vert^2 + 2L\eta_t^2\sum\, p_k(F_k(w_t^k) - F_k^*) + \\ &\quad\eta_t\sum\, p_k\left(\frac{1}{\eta_t} \Vert \bar w_t - w_t^k\Vert^2+\eta_t \Vert \nabla F_k(w_t^k)\Vert^2\right) -\\ &\quad2\eta_t \sum\, p_k\left(F_k(w_t^k)-F_k(w^*) + \frac{\mu}{2}\Vert w_t^k-w^*\Vert^2\right)\\ & \leq(1-\mu\eta_t)\Vert \bar w_t- w^*\Vert^2 + \sum\, p_k \Vert \bar w_t-w^k_t\Vert^2+ \\&\quad4L\eta_t^2 \sum \, p_k(F_k(w_t^k)-F_k^*) - 2\eta_t\sum\, p_k(F_k(w_t^k)-F_k(w^*)) \end{align*} \]

where the last step uses \(\sum\, p_k\Vert w_t^k - w^*\Vert^2 \geq \Vert \bar w_t - w^*\Vert^2\) (convexity of the squared norm).

Now take the last two terms, \(4L\eta_t^2 \sum \, p_k(F_k(w_t^k)-F_k^*) - 2\eta_t\sum\, p_k(F_k(w_t^k)-F_k(w^*))\). Define \(\gamma_t=2\eta_t(1-2L\eta_t)\), and assume \(\eta_t\) is non-increasing in \(t\) (i.e. the learning rate decays) with \(\eta_t\leq \frac{1}{4L}\). Then \(2L\eta_t\leq\frac12\), so \(1-2L\eta_t\in[\frac12, 1)\) and hence \(\eta_t\leq\gamma_t\leq 2 \eta_t\). Rearranging,

\[\begin{align*} &4L\eta_t^2 \sum \, p_k(F_k(w_t^k)-F_k^*) - 2\eta_t\sum\, p_k(F_k(w_t^k)-F_k(w^*))\\ &=-2\eta_t(1-2L\eta_t)\sum \, p_k(F_k(w_t^k)-F_k^*)+2\eta_t\sum\, p_k(F_k(w^*)-F_k^*)\\ &=-\gamma_t \sum \, p_k(F_k(w_t^k)-F^*) + (2\eta_t - \gamma_t)\sum\, p_k(F^*-F_k^*)\\ &=-\gamma_t \sum\, p_k(F_k(w_t^k)-F^*) + 4L\eta_t^2\Gamma \end{align*} \]

The third equality splits \(F_k(w_t^k)-F_k^*\) into \(F_k(w_t^k)-F^*+F^* - F_k^*\), and uses \(\sum\, p_k F_k(w^*)=F^*\). Next, bound \(\sum\, p_k(F_k(w_t^k)-F^*)\) from below:

\[\begin{align*} \sum\, p_k(F_k(w_t^k)-F^*) & = \sum\, p_k (F_k(w_t^k)-F_k(\bar w_t)) + \sum\, p_k(F_k(\bar w_t)-F^*)\\ & \geq \sum\, p_k \langle\nabla F_k(\bar w_t), w_t^k-\bar w_t\rangle +F(\bar w_t)-F^*\\ &\geq -\frac 1 2\sum\, p_k \left[\eta_t\Vert \nabla F_k(\bar w_t)\Vert^2 + \frac{1}{\eta_t}\Vert w_t^k -\bar w_t\Vert^2\right] + (F(\bar w_t)-F^*)\\ &\geq -\sum\, p_k \left[\eta_t L(F_k(\bar w_t) - F_k^*) + \frac{1}{2\eta_t}\Vert w_t^k -\bar w_t\Vert^2\right] + (F(\bar w_t)-F^*) \end{align*} \]

Putting these together, and using \(\sum\, p_k(F_k(\bar w_t)-F_k^*)=F(\bar w_t)-F^*+\Gamma\),

\[\begin{align*} &4L\eta_t^2 \sum \, p_k(F_k(w_t^k)-F_k^*) - 2\eta_t\sum\, p_k(F_k(w_t^k)-F_k(w^*))\\ &\leq \gamma_t \sum\, p_k \left[\eta_t L(F_k(\bar w_t) - F_k^*) +\frac{1}{2\eta_t}\Vert w_t^k -\bar w_t\Vert^2\right]-\gamma_t(F(\bar w_t)-F^*) + 4L\eta_t^2\Gamma\\ & = \gamma_t(\eta_tL-1)(F(\bar w_t) - F^*) + (4L\eta_t^2+\gamma_t\eta_tL)\Gamma+\frac{\gamma_t}{2\eta_t}\sum \, p_k \Vert w_t^k - \bar w_t\Vert^2\\ &\leq 6L\eta_t^2\Gamma + \sum\, p_k \Vert w_t^k - \bar w_t\Vert^2 \end{align*} \]

The last inequality holds because \(\eta_tL-1\leq -\frac{3}{4}\) and \(F(\bar w_t) - F^*\geq0\) make the first term non-positive, while \(\gamma_t\eta_tL\leq 2L\eta_t^2\) and \(\frac{\gamma_t}{2\eta_t}\leq 1\).

Therefore

\[\Vert \bar w_t -\eta _t \bar g_t -w^*\Vert ^2\leq (1-\mu \eta_t)\Vert \bar w_t- w^*\Vert^2 +6L\eta_t^2\Gamma + 2\sum\, p_k \Vert w_t^k - \bar w_t\Vert^2 \]

(note the divergence term \(\sum\, p_k \Vert w_t^k - \bar w_t\Vert^2\) appears twice: once from the Cauchy–Schwarz step and once from the bound just derived).

Adding the bounded-variance assumption on the stochastic gradients, and using that the clients sample independently (so the cross terms vanish in expectation),

\[\begin{align*} E\Vert g_t - \bar g_t\Vert^2 &= E\Vert \sum \, p_k(\nabla F_k(w_t^k, \xi_t^k) - \nabla F_k(w_t^k))\Vert^2\\ & \leq \sum\, p_k^2 \sigma_k^2 \end{align*} \]
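A Monte Carlo check of this bound, with synthetic independent per-client gradient noise (the weights \(p_k\), variances \(\sigma_k^2\), and dimension are made up for illustration); with independent noise the bound holds with equality in expectation:

```python
import numpy as np

rng = np.random.default_rng(2)
N, d, trials = 3, 4, 100_000
p = np.array([0.5, 0.3, 0.2])          # client weights p_k
sigma2 = np.array([1.0, 4.0, 0.25])    # sigma_k^2 = E||noise_k||^2

# per-coordinate std chosen so that E||noise_k||^2 = sigma_k^2
scale = np.sqrt(sigma2 / d)
# g_t - \bar g_t = sum_k p_k * (stochastic gradient - true gradient of client k)
diff = sum(p[k] * rng.normal(scale=scale[k], size=(trials, d)) for k in range(N))
lhs = float(np.mean(np.sum(diff ** 2, axis=1)))  # E||g_t - \bar g_t||^2
rhs = float(p ** 2 @ sigma2)                     # sum_k p_k^2 sigma_k^2
```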

Let \(t_0\) be the most recent communication round before \(t\), so \(t-t_0\leq E-1\), and since \(\eta_t\) is non-increasing, \(\eta_{t_0}\leq 2\eta_t\). Then the divergence of the local iterates is bounded by

\[\begin{align*}E\sum\, p_k \Vert \bar w_t - w^k_t\Vert^2 &= E\sum \,p_k\Vert (w^k_t-\bar w_{t_0})-(\bar w_t - \bar w_{t_0})\Vert^2\\ &\leq E\sum\, p_k \Vert w_t^k - \bar w_{t_0}\Vert^2\\ &\leq \sum\, p_k\, E\sum\nolimits_{\tau=t_0}^{t-1}\, (E-1)\eta^2_\tau \Vert \nabla F_k(w_\tau^k, \xi_\tau^k)\Vert^2\\ &\leq \sum\nolimits_{\tau=t_0}^{t-1}(E-1)\eta_\tau^2 G^2\\ &\leq 4\eta_t^2 (E-1)^2 G^2 \end{align*} \]

Finally, letting \(\Delta_t=E\Vert \bar w_{t}-w^*\Vert^2\) and combining the bounds above,

\[\begin{align*} \Delta_{t+1}\leq(1-\mu\eta_t)\Delta_t + \eta_t^2 B \end{align*} \]

where \(B=\sum\, p_k^2 \sigma_k^2+6L\Gamma+8(E-1)^2G^2\).
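Iterating this recursion numerically with a decaying step size \(\eta_t = \frac{2}{\mu(t+\gamma)}\) (the kind of schedule the analysis assumes; the constants \(\mu\), \(L\), \(B\), \(\Delta_0\) below are illustrative, with \(\gamma\) chosen so that \(\eta_t \leq \frac{1}{4L}\)) exhibits the expected \(O(1/t)\) decay of \(\Delta_t\):

```python
# Iterate the worst case of the bound: Delta_{t+1} = (1 - mu*eta_t)*Delta_t + eta_t^2 * B
mu, L, B = 1.0, 4.0, 10.0
gamma = 8 * L / mu                     # guarantees eta_t = 2/(mu*(t+gamma)) <= 1/(4L)
delta = 100.0                          # Delta_0
history = [delta]
for t in range(10_000):
    eta = 2.0 / (mu * (t + gamma))
    assert eta <= 1 / (4 * L) + 1e-12  # step-size condition from the proof
    delta = (1 - mu * eta) * delta + eta ** 2 * B
    history.append(delta)
```

After 10,000 steps the bound has decayed by several orders of magnitude, consistent with a \(\Delta_t \lesssim \frac{4B}{\mu^2 t}\) rate.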

References

  1. Francis Bach, Statistical machine learning and convex optimization

