目錄
3.1 離散測度 (Discrete measures)
3.2 蒙日(Monge)問題
3.3 Kantorovich Relaxation (松弛的蒙日問題)
3.4 Wasserstein距離
3.5 最優運輸問題初解
3.6 熵(Entropic)正則化
3.7 Sinkhorn算法 (NIPS, 2013)
引言
最優運輸(Optimal Transport)近年來引起了廣大學者的研究興趣,並在NIPS和ICML等機器學習頂級會議頻繁出現。然而,最優運輸的基本理論對於初學者來說,並不友好:初看理論,感覺全是晦澀難懂的數學推理公式,讓很多讀者有點望而卻步的感覺。此外,目前國內關於最優運輸理論的研究還比較初步,相關中文資料也比較匱乏。因此,筆者對自己最近幾天在網上博客、論文和視頻等資料的學習過程進行了初步整理,希望對后續的初學者提供一點幫助。
我的入門體驗:最優運輸相關理論的代碼庫已比較豐富(需要代碼,可以去github上搜索,或者檢索關於最優運輸的熱門頂會論文,基本都有開源代碼),並且核心理論也沒那么復雜,或者說只要你認真閱讀完本文,我相信你應該能夠較順暢地把最優運輸理論應用到你的實際應用中了。
1 背景
最優運輸問題最早是由法國數學家加斯帕德·蒙日(Gaspard Monge)在19世紀中期提出,它是一種將給定質量的泥土運輸到給定洞里的最小成本解決方案。這個問題在20世紀中期重新出現在坎托羅維奇的著作中,並在近些年的研究中發現了一些令人驚訝的新進展,比如Sinkhorn算法。最優運輸被廣泛應用於多個領域,包括計算流體力學,多幅圖像之間的顏色轉移或圖像處理背景下的變形,計算機圖形學中的插值方案,以及經濟學、通過匹配和均衡問題等。此外,最優傳輸最近也引起了生物醫學相關學者的關注,並被廣泛用於單細胞RNA發育過程中指導分化以及提高細胞觀測數據的數據增強工具,從而提高各種下游分任務的准確性和穩定性。
當前,許多現代統計和機器學習問題可以被重新描述為在兩個概率分布之間尋找最優運輸圖。例如,領域適應旨在從源數據分布中學習一個訓練良好的模型,並將該模型轉換為采用目標數據分布。另一個例子是深度生成模型,其目標是將一個固定的分布,例如標准高斯或均勻分布,映射到真實樣本的潛在總體分布。在最近幾十年里,OT方法在現代數據科學應用的顯著增殖中重新煥發了活力,包括機器學習、統計和計算機視覺。
2 什么是最優運輸?
參考資料[1]的見解:就是把A數據遷移到B。你可以理解成兩堆土,從A土鏟到另外一個地方,最終堆成B土。就像是以前初中學的線性規划一樣的:3個城市(A, B, C)有1, 0.5, 1.5噸煤,然后要運到2個其他城市,這兩個城市(C, D)分別需要2,1噸煤。然后,不同城市到不同的費用不同,讓你算最優運輸方案和代價。
Jingyi Zhang等人[2]的綜述表明:假設一個操作者經營着n個倉庫和m個工廠。每個倉庫都包含一定數量的有價值的原材料(即工廠正常運行所需要的資源),而且每個工廠對原材料都有一定的需求。假設n個倉庫中資源的總量等於m個工廠對原材料的總需求,運營商的目標是將所有的資源從倉庫轉移到工廠,從而能夠成功地滿足工廠的所有需求,並且總運輸成本盡可能小。具體如下圖1所示的資源分配問題。
通過上述兩個簡單的描述,相信大家應該知道什么是最優運輸了。在這里,你可能對於該理論還是有一點模糊的感覺。但是,請保持這顆好奇心往下看看具體的基本理論,下文將會結合參考文獻[3](關於最優傳輸的一本很經典的開源書籍)提供的公式和圖例來詳細介紹。
3基本概念
3.1離散測度 (Discrete measures)
首先,說一下概率向量(或者稱為直方圖,英文:Histograms, probability vector)的定義:
上述公式的含義:一個長度為n的數組,每個元素的值在[0, 1]之間,並且該數組的和為1,即表示的是一個概率分布向量。
離散測度:所謂測度就是一個函數,把一個集合中的一些子集(符合上述概率分布向量)對應給一個數[4]。具體公式定義如下:
上述公式含義:以a_i為概率和對應位置x_i 的狄拉克δ函數值乘積的累加和。下圖很好地闡述了一組不同元素點的概率向量分布:
上圖中紅色點是均勻的概率分布,藍色點是任意的概率分布。點狀分布對應是一維數據的概率向量分布,而點雲狀分布對應的是二維數據的概率向量分布。
3.2蒙日(Monge)問題
蒙日(Monge)問題的定義:找出從一個 measure到另一個measure的映射,使得所有c ( x i , y j )的和最小,其中c表示映射路線的運輸代價,需要根據具體應用定義。蒙日問題具體的定義公式如下:
對於上述公式的解釋可以采用離散測度來解釋,對於兩個離散測度:
找到一個n維映射到m維的一個映射,使得
上述映射的示意圖如下:
對於上述的映射公式,結合蒙日問題的定義公式,可以歸納如下:
上述公式的含義:通過這個映射T(X_i)的轉移,使得轉移到b_j的所有a_i的值的和剛好等於b_j(其中要求,所有a_i必須轉走,而所有b_j必須收到預期的貨物),即我需要多少就給運輸轉移多少,不能多也不能少。其中c()表示運輸代價,T(x_i)表示映射的運輸方案。
3.3 Kantorovich Relaxation (松弛的蒙日問題)
蒙日問題是最優運輸的起初最重要的思想,然而其有一個很大的缺點: 從a的所有貨物運輸到b時,只能采用原始的貨物大小進行運算,即不能對原始的貨物進行拆開發送到不同目的地。而Kantorovich Relaxation則對蒙日問題進行了松弛處理,即原始的貨物可以分開發送到不同目的地,也可以把蒙日問題理解為Kantorovich Relaxation的一個映射運輸特例。具體區別可以參考下圖[2]。
符合Kantorovich Relaxation的映射運輸定義公式如下:
區別於蒙日問題要求映射運輸的所有a_i一對一轉走到b_j。Kantorovich Relaxation只要求,所有每個a_i中獲取能夠完全轉走,可以是只轉給一個b_j,也可以是多個b_j,但是要確保每個b_j只需要收取預期要求的貨物即可。簡單地描述:行求和對應向量a, 列向量求和對應向量 b.具體的傳輸示例如下:
最后,Kantorovich Relaxation的最優運輸求解公式定義如下:
其中P表示符合所有行求和為向量a,所有列求和為向量b的一個映射。Pi,j表示從第i行映射到第j行的元素值,Ci,j表示完成Pi,j元素映射(或者說是運輸)的運輸代價。
3.4 Wasserstein距離
距離度量是機器學習任務中最重要的一環。比如,常見的人工神經網絡的均方誤差損失函數采用的就是熟知的歐式距離。然而,在最優運輸過程中,優於不同兩點之間均對應不同的概率,如果直接采用歐式距離來計算運輸的損失(或者說對運輸的過程進行度量和評估),則會導致最終的評估結果出現較大的偏差(即忽略了原始不同點直接的概率向量定義)。
針對上述問題,為了對最優運輸選擇的映射路徑好壞進行評估,Wasserstein距離應運而生,其公式和相關引理定義如下:
此處的距離計算公式看起來比較復雜,但是實際上該方法已有代碼庫[5]封裝好,只需要把對應的向量a和其包含的概率分布,以及向量b和其包含的概率分布輸入到封裝好的函數中,即可得到最終的Wasserstein距離。關於此處的介紹和理解,建議參考文末的參考資料[6],其部分解釋可以見下圖:
3.5最優運輸問題初解
最優運輸問題的解是一般是求取Kantorovich Relaxation的解,其可以采用線性規划的標准型來定義和實現求解[7]。了解過線性規划理論知識的同學應該清楚:線性規划求解一般是取可行域內的頂點值,才是最終需求的最優解(最小值或者最大值,具體選取看實際的可行域約束條件)。因此,線性規划的最優解,只可能是可行域表示的可行多面體的一個頂點。
依據參考資料[7]和文獻[3],對於線性規划尋找頂點解時,當判斷其是否是一個最優解,需要符合以下條件:如果P是一個頂點解,那么P中有質量流的路徑一定不行成一個環。這同時也意味着P中最多只能有n + m − 1條不為零的質量流。具體的示意圖如下:
上圖每條連線表示一個質量流,但其其中存在環,可知一定不是最優解。這便是求解最優運輸的幾何解釋,更多的示例,請參考文獻[3]的3.4節,第43頁。
依據上述的思路,采用線性規划求取最優運輸的最優解,有學者提出了采用西北角算法,其可以在經過n + m步計算后,搜索出一個U ( a , b )的頂點。更多具體的解釋和介紹請見參考資料[8]。
然而采用西北角算法每次的檢索結果只有一個頂點,該頂點並不一定是最優解。為了解決該問題,網絡單純形法出現了,該方法通過從可行多面體的一個頂點出發,每一步都到達一個離最優更接近的頂點,逐步達到最優。然而,單純形的最差復雜度是指數級的,不過它的平均復雜度卻非常高效,一般在多項式時間內找到最優解。關於單純形法的具體更多介紹,請見參考資料[9]。下方給出一個單純形法迭代求取最優解的示例圖,具體的解釋可以參考文獻[3]的3.5.3節。
3.6 熵(Entropic)正則化
在大部分應用情況下,求標准Kantorovich Relaxation解是不必要的:如果我們利用正則化,改求近似解,那么最優傳輸的計算代價就大幅降低了[10]。熵正則化的定義公式如下:
對Kantorovich Relaxation解添加正則化后,求解最優傳輸問題的定義如下:
其中參數thegama表示正則化系數,其作用和常用的人工神經網絡等方法中的正則化系數一樣。其中參數P=diag(u)*K*diag(v),其中u和v表示映射的一組解,並參數Pi,j滿足以下定義:
此外,由蒙日問題中定義的所有行的和組成的向量a和所有列的和組成的向量b和求解的映射u,v解的關系如下:
可以推導處如下一個結論,具體的證明推理過程請見參考文獻64頁[3]。
依據參考資料[10]的講解:正則化鼓勵利用多數小流量路徑的傳輸,而懲罰稀疏的,利用少數大流量路徑的傳輸,由此達到減少計算復雜度的目的。具體的解釋可以參考下述的示意圖:
由上圖可知,當參數thegama越大時,最優解的耦合程度變得越加稀疏,即不同解之間距離越大。通過熵正則化的處理,求取近似解的過程,能夠有效降低獲取理想解的時間。
3.7 Sinkhorn算法 (NIPS, 2013)
熵正則化獲取的近似解雖然能夠有效降低算法的時間復雜度,但是其潛力還未被充分挖掘。
Sinkhorn算法基於熵正則化的思想,提供一種更加巧妙的求解向量u和v的解法(得到u和v的解,就可以認為得到了Kantorovich Relaxation問題的對偶解,也就是最終的最優解。此處關於其對偶問題的定義和解釋,可以參考文獻[3]的23頁講解,以及參考資料[11][12])。具體講解請見參考資料[13],部分核心內容如下:
又因為:
結合以上定義,Sinkhorn算法[14]求解u和v的定義公式如下:
上述公式對應的算法偽碼如下[14]:
4 Wasserstein GAN (WGAN) 填補 (ICML, 2017)
此處時一篇2017年的ICML文章,結合了最優運輸中的Wasserstein距離來做填補。其填補的核心總結的描述如下[1],WGAN工作原文請見文獻[15]。
5 最優運輸填補 (ICML, 2020)
接下來最激動人心的時刻來的,即如何采用最優運輸理論進行含缺失數據的填補。原文請見文獻[16]。
原文作者提供的代碼鏈接:https://github.com/BorisMuzellec/MissingDataOT
本文工作最大的亮點:采用最優運輸的Wasserstein距離、.Entropic 正則化以及Sinkhorn算法理論,筆者認為其實首次將其應用到了含缺失數據的填補,並且填補的性能要優於之前已提出的方法。
本文工作的原理:采用熵正則化和Sinkhorn分歧來計算兩個數據分布之間的差異,相關公式如下:
采用最優運輸進行填補假設定義:隨機從原始的數據集中選取兩個bachsize(默認設定為128)大小的數據,該兩組數據的分布應該是接近或者理想上是一樣的分布。那么,先填補,然采用上述截圖中的公式3(Sinkhorn分歧)計算兩個大小為bachsize數據分布之間的相似度,如果填補的值越接近實際值,計算的分布相似度值就越優。
依據這樣的思想,作者提出了算法1,將填補的值當作反向傳播算法待優化更新的參數,采用公式(3)作者損失函數,通過不斷迭代更新梯度,不斷優化填補的值,使得最終填補的結果越接近實際的真實值。具體的算法偽碼如下:
算法1存在一個問題,其填補的值是采用反向傳播算法來更新,導致對於新的數據在沒有輸入到算法1中進行反向傳播更新訓練時,是無法執行填補的。換句話說,就是一個無參數話的填補算法,不具有模型的遷移學習能力。
針對上述問題,作者提出采用一個線性分類器或者MLP分類器對填補的值進行預測,采用公式(3)作為損失函數,對線性分類器或者MLP分類器的權重參數采用反向傳播算法進行更新。最終學習的線性分類器就是一個帶參數的填補模型,能夠對新的含缺失數據執行填補,文章中也稱其為參數化填補算法,具體的算法偽碼如下:
上述的算法2是直接對原始數據進行填補,也就是說無論其一條數據中有多少列屬性含缺失,都是只執行一次填補,這樣會導致無法解耦不同列之間的填補(PS:筆者認為此處提這個原因,可能是直接一次填補的效果不好,不過這個在文章沒有進行實驗對比分析)。針對上述問題,如果數據集含有n列,則采用n-1列數據對當前的1列數據執行填補,具體的算法偽碼如下所示:
然而,筆者通過文章實驗結果以及實際實驗發現,算法3在時間復雜度上很高。即如果原始數據集的維度很高,那么采用算法3執行填補將會耗費很多時間,導致其實用性不高,對於這個問題原文作者沒提。不過對於高維含缺失數據的填補,應該也是屬於另外一個研究問題,比如可能需要先執行降維,然后對降維后的數據進行填補。
參考資料
[1] “最優傳輸之淺談_Hungryof的專欄-CSDN博客.” https://blog.csdn.net/Hungryof/article/details/110549879 (accessed Mar. 11, 2021).
[2] J. Zhang, W. Zhong, and P. Ma, “A Review on Modern Computational Optimal Transport Methods with Applications in Biomedical Research,” arXiv:2008.02995 [cs, stat], Sep. 2020, Accessed: Mar. 11, 2021. [Online]. Available: http://arxiv.org/abs/2008.02995.
[3] G. Peyré and M. Cuturi, “Computational Optimal Transport,” arXiv:1803.00567 [stat], Mar. 2020, Accessed: Mar. 11, 2021. [Online]. Available: http://arxiv.org/abs/1803.00567.
[4] “最優傳輸系列-第一篇_Grant Tour of Algorithms-CSDN博客.” https://blog.csdn.net/Utterly_Bonkers/article/details/88387081?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_title-4&spm=1001.2101.3001.4242 (accessed Mar. 11, 2021).
[5] “scipy.stats.wasserstein_distance — SciPy v1.6.1 Reference Guide.” https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html (accessed Mar. 11, 2021).
[6] “最優傳輸系列-第二篇_Grant Tour of Algorithms-CSDN博客.” https://blog.csdn.net/Utterly_Bonkers/article/details/88613536?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.control&dist_request_id=1328627.20164.16154281942103719&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.control (accessed Mar. 11, 2021).
[7] “最優傳輸系列-第四篇(3.1-3.2)_Grant Tour of Algorithms-CSDN博客.” https://blog.csdn.net/Utterly_Bonkers/article/details/88758099?spm=1001.2014.3001.5501 (accessed Mar. 12, 2021).
[8] “最優傳輸系列-第六篇(3.4.2)_Grant Tour of Algorithms-CSDN博客.” https://blog.csdn.net/Utterly_Bonkers/article/details/89009325?spm=1001.2014.3001.5501 (accessed Mar. 12, 2021).
[9] “最優傳輸系列-第七篇(3.5-3.5.2)_Grant Tour of Algorithms-CSDN博客.” https://blog.csdn.net/Utterly_Bonkers/article/details/89325557?spm=1001.2014.3001.5501 (accessed Mar. 12, 2021).
[10] “最優傳輸-熵正則化(第八篇)_Grant Tour of Algorithms-CSDN博客.” https://blog.csdn.net/Utterly_Bonkers/article/details/89546491?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.control&dist_request_id=1328627.22339.16154448009187621&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.control (accessed Mar. 11, 2021).
[11] “凸優化中的對偶(Duality in General Programs)_zbwgycm的博客-CSDN博客.” https://blog.csdn.net/zbwgycm/article/details/104752762 (accessed Mar. 11, 2021).
[12] “最優傳輸系列-第三篇(2.5)_Grant Tour of Algorithms-CSDN博客.” https://blog.csdn.net/Utterly_Bonkers/article/details/88713539?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.control&dist_request_id=1328627.20562.16154282276672995&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.control (accessed Mar. 12, 2021).
[13] “最優傳輸-Sinkhorn算法(第九篇)_Grant Tour of Algorithms-CSDN博客_sinkhorn.” https://blog.csdn.net/Utterly_Bonkers/article/details/90746259?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.control&dist_request_id=&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.control (accessed Mar. 11, 2021).
[14] M. Cuturi, “Sinkhorn Distances: Lightspeed Computation of Optimal Transportation Distances,” arXiv:1306.0895 [stat], Jun. 2013, Accessed: Mar. 11, 2021. [Online]. Available: http://arxiv.org/abs/1306.0895.
[15] M. Arjovsky, S. Chintala, and L. Bottou, “Wasserstein generative adversarial networks,” in Proceedings of the 34th International Conference on Machine Learning - Volume 70, Sydney, NSW, Australia, Aug. 2017, pp. 214–223, Accessed: Mar. 10, 2021. [Online].
[16] B. Muzellec, J. Josse, C. Boyer, and M. Cuturi, “Missing Data Imputation using Optimal Transport,” arXiv:2002.03860 [cs, stat], Jul. 2020, Accessed: Mar. 12, 2021. [Online]. Available: http://arxiv.org/abs/2002.03860.