http://www.datakit.cn/blog/2017/02/05/t_sne_full.html
t-SNE(t-distributed stochastic neighbor embedding)是用於降維的一種機器學習算法,是由 Laurens van der Maaten 和 Geoffrey Hinton在08年提出來。此外,t-SNE 是一種非線性降維算法,非常適用於高維數據降維到2維或者3維,進行可視化。
t-SNE是由SNE(Stochastic Neighbor Embedding, SNE; Hinton and Roweis, 2002)發展而來。我們先介紹SNE的基本原理,之后再擴展到t-SNE。最后再看一下t-SNE的實現以及一些優化。
目錄
1.SNE
1.1基本原理
SNE是通過仿射(affinitie)變換將數據點映射到概率分布上,主要包括兩個步驟:
- SNE構建一個高維對象之間的概率分布,使得相似的對象有更高的概率被選擇,而不相似的對象有較低的概率被選擇。
- SNE在低維空間里在構建這些點的概率分布,使得這兩個概率分布之間盡可能的相似。
我們看到t-SNE模型是非監督的降維,他跟kmeans等不同,他不能通過訓練得到一些東西之后再用於其它數據(比如kmeans可以通過訓練得到k個點,再用於其它數據集,而t-SNE只能單獨的對數據做操作,也就是說他只有fit_transform,而沒有fit操作)
1.2 SNE原理推導
SNE是先將歐幾里得距離轉換為條件概率來表達點與點之間的相似度。具體來說,給定一個N個高維的數據 x1,...,xNx1,...,xN(注意N不是維度), t-SNE首先是計算概率pijpij,正比於xixi和xjxj之間的相似度(這種概率是我們自主構建的),即:
這里的有一個參數是σiσi,對於不同的點xixi取值不一樣,后續會討論如何設置。此外設置px∣x=0px∣x=0,因為我們關注的是兩兩之間的相似度。
那對於低維度下的yiyi,我們可以指定高斯分布為方差為12√12,因此它們之間的相似度如下:
同樣,設定qi∣i=0qi∣i=0.
如果降維的效果比較好,局部特征保留完整,那么 pi∣j=qi∣jpi∣j=qi∣j, 因此我們優化兩個分布之間的距離-KL散度(Kullback-Leibler divergences),那么目標函數(cost function)如下:
這里的PiPi表示了給定點xixi下,其他所有數據點的條件概率分布。需要注意的是KL散度具有不對稱性,在低維映射中不同的距離對應的懲罰權重是不同的,具體來說:距離較遠的兩個點來表達距離較近的兩個點會產生更大的cost,相反,用較近的兩個點來表達較遠的兩個點產生的cost相對較小(注意:類似於回歸容易受異常值影響,但效果相反)。即用較小的 qj∣i=0.2qj∣i=0.2 來建模較大的 pj∣i=0.8pj∣i=0.8, cost=plog(pq)plog(pq)=1.11,同樣用較大的qj∣i=0.8qj∣i=0.8來建模較大的pj∣i=0.2pj∣i=0.2, cost=-0.277, 因此,SNE會傾向於保留數據中的局部特征。
思考:了解了基本思路之后,你會怎么選擇σσ,固定初始化?
下面我們開始正式的推導SNE。首先不同的點具有不同的σiσi,PiPi的熵(entropy)會隨着σiσi的增加而增加。SNE使用困惑度(perplexity)的概念,用二分搜索的方式來尋找一個最佳的σσ。其中困惑度指:
這里的H(Pi)H(Pi)是PiPi的熵,即:
困惑度可以解釋為一個點附近的有效近鄰點個數。SNE對困惑度的調整比較有魯棒性,通常選擇5-50之間,給定之后,使用二分搜索的方式尋找合適的σσ
那么核心問題是如何求解梯度了,目標函數等價於∑∑−plog(q)∑∑−plog(q)這個式子與softmax非常的類似,我們知道softmax的目標函數是∑−ylogp∑−ylogp,對應的梯度是y−py−p(注:這里的softmax中y表示label,p表示預估值)。 同樣我們可以推導SNE的目標函數中的i在j下的條件概率情況的梯度是2(pi∣j−qi∣j)(yi−yj)2(pi∣j−qi∣j)(yi−yj), 同樣j在i下的條件概率的梯度是2(pj∣i−qj∣i)(yi−yj)2(pj∣i−qj∣i)(yi−yj), 最后得到完整的梯度公式如下:
在初始化中,可以用較小的σσ下的高斯分布來進行初始化。為了加速優化過程和避免陷入局部最優解,梯度中需要使用一個相對較大的動量(momentum)。即參數更新中除了當前的梯度,還要引入之前的梯度累加的指數衰減項,如下:
這里的Y(t)Y(t)表示迭代t次的解,ηη表示學習速率,α(t)α(t)表示迭代t次的動量。
此外,在初始優化的階段,每次迭代中可以引入一些高斯噪聲,之后像模擬退火一樣逐漸減小該噪聲,可以用來避免陷入局部最優解。因此,SNE在選擇高斯噪聲,以及學習速率,什么時候開始衰減,動量選擇等等超參數上,需要跑多次優化才可以。
思考:SNE有哪些不足? 面對SNE的不足,你會做什么改進?
2.t-SNE
盡管SNE提供了很好的可視化方法,但是他很難優化,而且存在”crowding problem”(擁擠問題)。后續中,Hinton等人又提出了t-SNE的方法。與SNE不同,主要如下:
- 使用對稱版的SNE,簡化梯度公式
- 低維空間下,使用t分布替代高斯分布表達兩點之間的相似度
t-SNE在低維空間下使用更重長尾分布的t分布來避免crowding問題和優化問題。在這里,首先介紹一下對稱版的SNE,之后介紹crowding問題,之后再介紹t-SNE。
2.1 Symmetric SNE
優化pi∣jpi∣j和qi∣jqi∣j的KL散度的一種替換思路是,使用聯合概率分布來替換條件概率分布,即P是高維空間里各個點的聯合概率分布,Q是低維空間下的,目標函數為:
這里的piipii,qiiqii為0,我們將這種SNE稱之為symmetric SNE(對稱SNE),因為他假設了對於任意i,pij=pji,qij=qjipij=pji,qij=qji,因此概率分布可以改寫為:
這種表達方式,使得整體簡潔了很多。但是會引入異常值的問題。比如xixi是異常值,那么∣∣xi−xj∣∣2∣∣xi−xj∣∣2會很大,對應的所有的j, pijpij都會很小(之前是僅在xixi下很小),導致低維映射下的yiyi對cost影響很小。
思考: 對於異常值,你會做什么改進?pipi表示什么?
為了解決這個問題,我們將聯合概率分布定義修正為: pij=pi∣j+pj∣i2pij=pi∣j+pj∣i2, 這保證了∑jpij>12n∑jpij>12n, 使得每個點對於cost都會有一定的貢獻。對稱SNE的最大優點是梯度計算變得簡單了,如下:
實驗中,發現對稱SNE能夠產生和SNE一樣好的結果,有時甚至略好一點。
2.2 Crowding問題
擁擠問題就是說各個簇聚集在一起,無法區分。比如有一種情況,高維度數據在降維到10維下,可以有很好的表達,但是降維到兩維后無法得到可信映射,比如降維如10維中有11個點之間兩兩等距離的,在二維下就無法得到可信的映射結果(最多3個點)。 進一步的說明,假設一個以數據點xixi為中心,半徑為r的m維球(三維空間就是球),其體積是按rmrm增長的,假設數據點是在m維球中均勻分布的,我們來看看其他數據點與xixi的距離隨維度增大而產生的變化。
從上圖可以看到,隨着維度的增大,大部分數據點都聚集在m維球的表面附近,與點xixi的距離分布極不均衡。如果直接將這種距離關系保留到低維,就會出現擁擠問題。
怎么解決crowding問題呢?
Cook et al.(2007) 提出一種slight repulsion的方式,在基線概率分布(uniform background)中引入一個較小的混合因子ρρ,這樣qijqij就永遠不會小於2ρn(n−1)2ρn(n−1) (因為一共了n(n-1)個pairs),這樣在高維空間中比較遠的兩個點之間的qijqij總是會比pijpij大一點。這種稱之為UNI-SNE,效果通常比標准的SNE要好。優化UNI-SNE的方法是先讓ρρ為0,使用標准的SNE優化,之后用模擬退火的方法的時候,再慢慢增加ρρ. 直接優化UNI-SNE是不行的(即一開始ρρ不為0),因為距離較遠的兩個點基本是一樣的qijqij(等於基線分布), 即使pijpij很大,一些距離變化很難在qijqij中產生作用。也就是說優化中剛開始距離較遠的兩個聚類點,后續就無法再把他們拉近了。
2.3 t-SNE
對稱SNE實際上在高維度下 另外一種減輕”擁擠問題”的方法:在高維空間下,在高維空間下我們使用高斯分布將距離轉換為概率分布,在低維空間下,我們使用更加偏重長尾分布的方式來將距離轉換為概率分布,使得高維度下中低等的距離在映射后能夠有一個較大的距離。
我們對比一下高斯分布和t分布(如上圖,code見probability/distribution.md), t分布受異常值影響更小,擬合結果更為合理,較好的捕獲了數據的整體特征。
使用了t分布之后的q變化,如下:
此外,t分布是無限多個高斯分布的疊加,計算上不是指數的,會方便很多。優化的梯度如下:
t-sne的有效性,也可以從上圖中看到:橫軸表示距離,縱軸表示相似度, 可以看到,對於較大相似度的點,t分布在低維空間中的距離需要稍小一點;而對於低相似度的點,t分布在低維空間中的距離需要更遠。這恰好滿足了我們的需求,即同一簇內的點(距離較近)聚合的更緊密,不同簇之間的點(距離較遠)更加疏遠。
總結一下,t-SNE的梯度更新有兩大優勢:
- 對於不相似的點,用一個較小的距離會產生較大的梯度來讓這些點排斥開來。
- 這種排斥又不會無限大(梯度中分母),避免不相似的點距離太遠。
2.4 算法過程
算法詳細過程如下:
- Data: X=x1,...,xnX=x1,...,xn
- 計算cost function的參數:困惑度Perp
- 優化參數: 設置迭代次數T, 學習速率ηη, 動量α(t)α(t)
- 目標結果是低維數據表示 YT=y1,...,ynYT=y1,...,yn
- 開始優化
- 計算在給定Perp下的條件概率pj∣ipj∣i(參見上面公式)
- 令 pij=pj∣i+pi∣j2npij=pj∣i+pi∣j2n
- 用 N(0,10−4I)N(0,10−4I) 隨機初始化 Y
- 迭代,從 t = 1 到 T, 做如下操作:
- 計算低維度下的 qijqij(參見上面的公式)
- 計算梯度(參見上面的公式)
- 更新 Yt=Yt−1+ηdCdY+α(t)(Yt−1−Yt−2)Yt=Yt−1+ηdCdY+α(t)(Yt−1−Yt−2)
- 結束
- 結束
優化過程中可以嘗試的兩個trick:
- 提前壓縮(early compression):開始初始化的時候,各個點要離得近一點。這樣小的距離,方便各個聚類中心的移動。可以通過引入L2正則項(距離的平方和)來實現。
- 提前誇大(early exaggeration):在開始優化階段,pijpij乘以一個大於1的數進行擴大,來避免因為qijqij太小導致優化太慢的問題。比如前50次迭代,pijpij乘以4
優化的過程動態圖如下:
2.5 不足
主要不足有四個:
- 主要用於可視化,很難用於其他目的。比如測試集合降維,因為他沒有顯式的預估部分,不能在測試集合直接降維;比如降維到10維,因為t分布偏重長尾,1個自由度的t分布很難保存好局部特征,可能需要設置成更高的自由度。
- t-SNE傾向於保存局部特征,對於本征維數(intrinsic dimensionality)本身就很高的數據集,是不可能完整的映射到2-3維的空間
- t-SNE沒有唯一最優解,且沒有預估部分。如果想要做預估,可以考慮降維之后,再構建一個回歸方程之類的模型去做。但是要注意,t-sne中距離本身是沒有意義,都是概率分布問題。
- 訓練太慢。有很多基於樹的算法在t-sne上做一些改進
3.變種
后續有機會補充。
- multiple maps of t-SNE
- parametric t-SNE
- Visualizing Large-scale and High-dimensional Data
4.參考文檔
- Maaten, L., & Hinton, G. (2008). Visualizing data using t-SNE. Journal of Machine Learning Research.
5. 代碼
文中的插圖繪制:
# coding:utf-8 import numpy as np from numpy.linalg import norm from matplotlib import pyplot as plt plt.style.use('ggplot') def sne_crowding(): npoints = 1000 # 抽取1000個m維球內均勻分布的點 plt.figure(figsize=(20, 5)) for i, m in enumerate((2, 3, 5, 8)): # 這里模擬m維球中的均勻分布用到了拒絕采樣, # 即先生成m維立方中的均勻分布,再剔除m維球外部的點 accepts = [] while len(accepts) < 1000: points = np.random.rand(500, m) accepts.extend([d for d in norm(points, axis=1) if d <= 1.0]) # 拒絕采樣 accepts = accepts[:npoints] ax = plt.subplot(1, 4, i+1) if i == 0: ax.set_ylabel('count') if i == 2: ax.set_xlabel('distance') ax.hist(accepts, bins=np.linspace(0., 1., 50)) ax.set_title('m=%s' %m) plt.savefig("./images/sne_crowding.png") x = np.linspace(0, 4, 100) ta = 1 / (1 + np.square(x)) tb = np.sum(ta) - 1 qa = np.exp(-np.square(x)) qb = np.sum(qa) - 1 def sne_norm_t_dist_cost(): plt.figure(figsize