分布式多任務學習:論文總結歸納和展望


1 論文總結歸納

做為最后一篇分布式多任務學習的論文閱讀記錄,我決定對我目前為止粗讀和精讀的論文進行一次總結,然后陳述一些個人對該研究領域的見解和想法。

目前已經有許多論文對多任務學習提出了分布式並行方案。在分布式多任務學習中,傳統的處理方式[1][2][3]仍然是基於主從(Master-Workers)架構:多個任務節點(Workers)分攤任務,然后將信息交給主節點(Master)匯總(比如在分布式近端映射算法中,任務節點進行梯度計算,主節點負責近端映射)。

但中心化模式自身存在弊端。首先,當網絡傳輸代價比較大時,容易在中心節點處形成瓶頸;其次,中心化方法對系統的穩定性要求高,因為它要求中心節點能夠穩定地聚合和分發模型,一旦中心節點出錯,整個任務必然失敗;最后,在某些現實場景中,任務節點只存在和其相鄰節點的網絡連接,這時中心化算法將無法工作。

基於以上中心化分布式學習算法的弊端,分布式多任務學習也越來越朝着去中心化的路線發展[4][5][6][7],也就是強調任務節點只能和其鄰接點相互通信,而不存在存放全局模型的一個主節點。加之近年來聯邦學習的發展,該領域往往會關注任務節點的數據隱私性,以及一些分布式計算中的經典問題,比如拜占庭容錯等。此外,由於在現實應用場景中任務之間的關系是未知的(並非做為一個先驗),而且出於對流數據的使用、低計算開銷的需求和對變化環境更強的適應性的考慮,常常需要引入在線學習。

因此分布式多任務學習也常常和在線學習結合。在線多任務學習中任務之間的鄰接權重矩陣可以表示任務之間的關系,而該權重也可以實時跟隨任務之間的參數距離或梯度變化程度來動態調整。

注1: 分布式多任務學習和聯邦學習的區別和聯系


我們前面提到,分布式多任務學習朝着聯邦學習的路線發展。但讀者其實聯邦學習和多任務學習原本是很不一樣的。在標准的聯邦學習中,每個節點任務不共享數據,但是可以共享參數,以此聯合訓練出各一個全局的模型。也就是說,聯邦學習下每個節點的任務是一樣的。而多任務學習是要針對不同的任務協同訓練出多個不同的模型。

但是,為什么分布式多任務學習會走向聯邦學習呢?其實,不是分布式多任務學習選擇了聯邦學習,而是聯邦學習選擇了分布式多任務學習。 原來,聯邦學習由於數據不獨立同分布,每個模型訓練出的局部模型差異會很大,就會使得構建一個全局的、通用的模型難度很大。比如同樣一個下一個單詞預測的任務,同樣給定"I love eating,",但對於下一個單詞每個client會給出不同的答案,這也是現在有人提出聯邦多任務學習的原因)。

為了解決聯邦學習中數據不獨立同分布的的問題,有論文[11][12]提出不求訓練出一個全局的模型,使每個節點訓練各不相同的模型這樣一種訓練方式,這就被冠名為聯邦多任務學習了。


注2: 分布式多任務學習和聯邦多任務學習的區別和聯系


此二者非常相似,但是聯邦多任務學習可以看做是分布式多任務學習在特殊條件下的限制版,即聯邦多任務學習中可能更關注節點的容錯性,以及節點數據集隱私(節點之間的數據不能共享),單純的分布式多任務學習一般沒這幾個需求。此外還有一點就是,按照最初的傳統聯邦多任務學習一般是有中心節點的(如論文[34]中所說),而分布式多任務學習是可以去中心化的(如論文[10]中所說)。但是也有論文把聯邦多任務學習也去中心化了([35]),所以這個應該算不上主要依據。


下面我就按照中心化和去中心化這兩個類別對現有分布式並行方式進行歸納總結。

1.1 中心化方法(centralized approach)

1.1.1 基於近端梯度的同步算法

描述 多任務學習優化中首先面臨的問題即目標函數\(F(\bm{\theta}) = f(\bm{\theta}) + g(\bm{\theta})\)中正則項\(g(\bm{\theta})\)的非凸性,而在數值優化里面近端梯度算法[8](包括一階的FISTA(近端梯度算法的一種變種)、SpaRSA和最近提出的二階的PNOPT)可以有效求解這種所謂的復合凸目標函數。而近端梯度算法一般分為\(f(\bm{\theta})\)的梯度計算和近端映射兩個步驟。\(f(\bm{\theta})\)的梯度計算可以很容易地分攤到各個任務節點上(每個任務節點負責計算一個任務對應的梯度);\(g(\bm{\theta})\)的計算不易分解,可以由主節點完成,即待所有任務節點運算完畢后,將梯度傳往主節點后匯集,並在此基礎上對權重矩陣進行近端映射,以得到更新后的參數矩陣。最后,主節點又將更新了的參數向量分發給各任務節點,開始下一輪迭代。

優點 算法邏輯簡單清晰且通用性強,幾乎可以推廣到所有基於正則化的多任務學習。

缺陷 因為是基於同步通信的優化算法,如果有節點傳輸帶寬過低(或直接down)掉,就會拖累整個系統的運行,導致不能容忍的運行時間和運算資源的浪費。

同步迭代框架

1.1.2 基於近端梯度的異步算法

描述 中心節點只要收到了來自一個任務節點的已經算好的梯度,就會馬上對模型的參數矩陣進行更新,而不用等待其他任務節點完成計算[1]。特別地,在論文[1]中采用了一種前向-后向算子分裂的視角[9][10]來求解目標函數(其實就是近端梯度法,但采用算子的視角可以抽象為不動點迭代,方便我們后面結合KM算法),且最終的迭代方法采用論文[11]中討論的經過異步坐標更新改造的KM迭代方法。

優點 因為通信是異步的,可以提高系統的吞吐率。

缺陷 如果對內存的讀取不加鎖的話會導致不一致性(inconsistency)問題,可能降低算法整體的收斂速率。

異步更新示意圖

1.1.3 基於分解代理損失函數的算法

描述 對於形如\(F(\bm{\theta}) = f(\bm{\theta}) + g(\bm{\theta})\)的復合目標函數,FISTA算法[12]采取的策略是在迭代過程中不斷構建代理損失函數\(Q_{\mathcal{L}}(\bm{\theta}, \hat{\bm{\theta}})\),然后通過優化該代理損失函數來更新參數。后面我們發現,對該代理損失函數的求解進行並行化較為容易。在論文《Parallel Multi-Task Learning》[13]中則更進一步,首先將原始的\(F(\bm{\theta}) = f(\bm{\theta}) + g(\bm{\theta})\)問題先轉換為了對偶問題,然后用FISTA算法對對偶問題進行迭代求解,在求解的過程中構建代理損失函數,然后對代理損失函數進行並行化。

優點 對代理損失函數可以按照固定的套路進行分解,從而使該方法具有較強的通用性。

缺陷 如果代理損失函數構建不恰當,則可能導致其做為原損失函數的上界過松,降低優化效果。而且論文只是將原問題分解為了子問題,然而子問題的求解仍然是串行的,可能會花費較多時間。

FISTA算法偽代碼

1.1.4 基於本地去偏估計的算法

描述 基於近端梯度的常規優化算法會帶來較大的通信開銷,但如果完全不通信就會退化為單任務學習。為了解決這個難點,論文《distributed multitask learning》[14]提出了基於去偏lasso的分布式算法。該算法適用於解決形如\(F(\bm{\theta}) = f(\bm{\theta}) + \text{pen}(\bm{\theta})\)的問題,其中\(\text{pen}(\bm{\theta})\)是group sparse 正則項。該算法介於常規的通信算法和不通信的算法之間,只需要一輪通信,但仍然保證了使用group regularization所帶來的統計學效益。

優點 直接在保證多任務學習特性的條件下,大大減少了分布式計算所帶來的的通信次數,是所有方法中加速效果最好的。

缺點 僅局限於基於group sparse正則項的多任務學習,難以進一步推廣。

去偏lasso算法

1.2 去中心化方法(decentralized approach)

1.2.1 用任務信息共享解決在線多任務學習的帕累托最優問題

描述 去中心化的在線多任務學習會面臨帕累托(Pareto optimality)最優問題(即每個任務的最優解不同,最終只能折中達到一個全局最優,而這個全局最優會犧牲每個任務的精度),Zhang C[4]等人提出了一個基於任務信息共享的去中心化優化方法。在該方法中,每個任務節點都有機會負責所有的任務(存儲有所有任務的權重)。每輪迭代之前每個節點會和相鄰的節點做權重組合(去中心化方法中的一個常見操作,用於收斂到一個全局最小),然后每個任務隨機接收另一個任務的數據,再根據此數據更新本地對應部分的權重。

優點 打破並行多任務學習中的常規,創新性地提出了任務間的數據共享,有利於收斂到全局最優。

缺陷 如果是在智能手機、無人駕駛等聯邦學習的環境下,任務節點之間的數據共享將不可行,使該方法存在一定的局限性。而且,該方法假設任務節點之間的關聯度是靜態且相同的(體現為鄰接矩陣的權重取值),而現實場景下任務節點之間的關聯度是不同甚至是動態的,我們可以考慮根據工作節點參數之間的距離或工作節點梯度變化程度來動態調整。

同步迭代框架

1.2.2 周期性同步的去中心化優化方法

描述 Yang P等人[5]設計了一個去中心化的原始-對偶優化方法來求解在線多任務學習問題。在該方法中采用了一種周期性同步的措施(周期為\(\tau\)),這種周期化同步的方法攤銷了同步操作所帶來的的(等待低速節點)的時間延遲。在論文中,鄰接權重矩陣\(\mathbf{S}\)固定為一個稀疏矩陣(當\(t \text{ mod } \tau =0\)時,\(\mathbf{S}_t = \mathbf{S}\);其余情況鄰接矩陣\(\mathbf{S}_t = \mathbf{I}_{m\times m}\),意味着節點間沒有同步與通信)。

優點 依靠周期性的優化算法降低了同步操作帶來的(等待低速節點)的開銷。

缺陷 本質上鄰接矩陣的權重仍然沒有根據工作節點的情況進行變化,

去中心化的迭代算法

1.2.3 分布式在線多任務學習中的拜占庭容錯問題

描述 Li J等人[6]設計了一個對任意數量的拜占庭攻擊者具有容錯性的任務節點之間的權重更新算法。在分布式在線多任務學習的環境下,任務節點之間的權重可以表示任務之間的關系,該權重可以動態地進行學習和調整。一種常見的調整方法為根據不同任務節點參數之間的距離進行調整。然而,這種調整方法受到拜占庭攻擊者的影響。論文設計了一種特別的權重調整方法,該方法能夠過濾掉來自鄰接點中的拜占庭攻擊者的信息,而只使用剩余鄰接點的信息。該算法對於任意數量的拜占庭攻擊者都適用。

優點 首先,該論文基於在線學習方法學習任務之間的關系(鄰接矩陣權重可調整),具有一定的優越性。其次,論文將拜占庭容錯性引入多任務學習研究,帶來了一個不一樣的視角;最后,每輪迭代只需要花費關於鄰接點數量和任務維度的線性時間復雜度的時間開計算權重更新,具有計算高效性。

缺陷 每輪迭代鄰接矩陣的權重更新需要解一個最優化問題,而該問題需要同步操作,對於低速的節點來說會到來較高的時間延遲(尤其當任務節點數量)較多時。

去中心化的迭代算法

1.2.4 基於混合分布的聯邦多任務學習模型

描述 Marfoq O等人[7]設計了一個基於混合數據分布的聯邦(分布式)多任務模型。該模型假設每個任務節點的數據是不能移動且並不是同分布的(這也是聯邦學習中經常面對的情況)。作者認為多任務學習正好可以用於解決聯邦學習中這種節點數據不是同分布的情況。該篇論文假設每個論文節點是\(M\)個概率分布的混合分布(對任務\(t\)的數據集,第\(j\)個概率分布對應的權重為\(\pi_{tj}^*\)),這樣的優點在於每個任務節點可以與其他任務節點共享知識。此外論文也假設每個任務節點的權重是\(M\)個權重參數的混合分布,還可以拓展到聚類多任務學習的情況(聚類多任務學習的一種情況[19]就是強調每個任務的參數分布是多個類簇分布的混合分布,每個類簇共享一個模型。應用在這里如果任務\(t\)在簇\(c\)中,則\(\pi_{tc}^* = 0\))。最后,作者給出了該算法的中心化和去中心化實現形式。其中去中心化實現中,任務節點將會和其鄰接節點交換權重。

優點 以多任務學習之矛攻聯邦學習之盾,針對聯邦學習數據分布不一致問題給出了較好的解決方案。

缺陷 聯邦學習的情境下增加了數據不能轉移的限制,一定程度上也導致了多任務可能無法收斂到全局最優。而且每輪迭代每個任務都需要同步參數交互操作,如果存在通信帶寬較小的節點將會導致較大的時延。此外,通過任務參數的混合分布,較好地和聚類多任務學習聯系起來。

去中心化的迭代算法

2 個人看法

根據目前我所閱讀的論文情況,⽬前的分布式多任務學習的論文大多關注於形如\(F(\bm{\theta})=f(\bm{\theta})+g(\bm{\theta})\)的基於正則化的多任務學習,因為這樣可以使分布式並行的方法更有通用性。目前我也准備集中解決這類問題的並行化。

我們回顧上面提到的所有方法的優缺點,在中心化模型方面,Baytas I M等人[1]提出的基於近端映射的同步和異步算法和Zhang Y[13]提出的基於分解代理損失函數的並行化方法,都着重於針對正則化多任務學習的提供一個通用的並行架構,但[1]會帶來較高的通信量,[13]的並行化程度非常有限。Wang J等人[14]提出的基於去偏lasso的分布式方法盡管大大減少了通信量,但只適用於增強group sparse的正則項(如group lasso等),不能進一步推廣到其他的基於正則化的方法(比如基於聚類/層次化的多任務學習)。可以說,在中心化模型方面,目前以上的工作都尚未對基於聚類/層次化的多任務學習⽅法提供一個加速比高的分布式並⾏手段。

此外,在去中心化模型方面,不同的論文在去中心化的基礎上由不同的角度做了許多工作。Zhang C等人[4]打破了任務間的數據壁壘,為了避免陷入多任務的帕累托最優情形,使每個任務之間可以共享數據,且每個任務節點將有機會負責所有任務。這種方法有悖於數據隱私保護的觀念,在如今聯邦學習的大環境下顯得並不可行。Yang P等人[5]方法創新點一是其使用的原始-對偶優化方法,而是其提出的周期性優化的思想,這種思想大大減小了去中心化環境中任務節點與相鄰節點間的通信等待時延。Li J[6]等人認為鄰接矩陣權重更新過程中的任務之間的參數相似度計算對拜占庭攻擊缺少容錯性,設計了一個拜占庭容錯的鄰接矩陣權重更新算法,這個從分布式系統/聯邦學習出發的視角對我很有啟發。 Marfoq O等人[7]從聯邦學習出發,因各節點的數據分布不同將多任務學習引入到聯邦學習情景中,體現了學科之間的交融性;同時也假設了各節點的權重服從混合分布,很好地和聚類多任務學習聯系起來。

目前我最感興趣的是和聚類多任務學習結合的[7]聯邦學習算法。后面的科研也會在這篇論文的基礎上進行改進,這里就不展開詳細敘述了。

參考

  • [1] Baytas I M, Yan M, Jain A K, et al. Asynchronous multi-task learning[C]//2016 IEEE 16th International Conference on Data Mining (ICDM). IEEE, 2016: 11-20.
  • [2] Liu S, Pan S J, Ho Q. Distributed multi-task relationship learning[C]//Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. 2017: 937-946.
  • [3] Dinuzzo F, Pillonetto G, De Nicolao G. Client–server multitask learning from distributed datasets[J]. IEEE Transactions on Neural Networks, 2010, 22(2): 290-303.
  • [4] Zhang C, Zhao P, Hao S, et al. Distributed multi-task classification: A decentralized online learning approach[J]. Machine Learning, 2018, 107(4): 727-747.
  • [5] Yang P, Li P. Distributed primal-dual optimization for online multi-task learning[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2020, 34(04): 6631-6638.
  • [6] Li J, Abbas W, Koutsoukos X. Byzantine Resilient Distributed Multi-Task Learning[J]. arXiv preprint arXiv:2010.13032, 2020.
  • [7] Marfoq O, Neglia G, Bellet A, et al. Federated multi-task learning under a mixture of distributions[J]. Advances in Neural Information Processing Systems, 2021, 34.
  • [8] Ji S, Ye J. An accelerated gradient method for trace norm minimization[C]//Proceedings of the 26th annual international conference on machine learning. 2009: 457-464.
  • [9] P. L. Combettes and V. R. Wajs, “Signal recovery by proximal forwardbackward splitting,” Multiscale Modeling & Simulation, vol. 4, no. 4, pp. 1168–1200, 2005.
  • [10] Z. Peng, T. Wu, Y. Xu, M. Yan, and W. Yin, “Coordinate-friendly structures, algorithms and applications,” Annals of Mathematical Sciences and Applications, vol. 1, pp. 57–119, 2016.
  • [11] Z. Peng, Y. Xu, M. Yan, and W. Yin, “ARock: An algorithmic framework for asynchronous parallel coordinate updates,” SIAM Journal on Scientific Computing, vol. 38, no. 5, pp. A2851–A2879, 2016.
  • [12] A. Beck and M. Teboulle, “A fast iterative shrinkagethresholding algorithm for linear inverse problems,” SIAM Journal on Imaging Sciences, 2009
  • [13] Zhang Y. Parallel multi-task learning[C]//2015 IEEE International Conference on Data Mining. IEEE, 2015: 629-638.
  • [14] Wang J, Kolar M, Srerbo N. Distributed multi-task learning[C]//Artificial intelligence and statistics. PMLR, 2016: 751-760.
  • [15] Smith V, Chiang C K, Sanjabi M, et al. Federated multi-task learning[J]. Advances in Neural Information Processing Systems, 2017.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM