t-SNE高維數據可視化(python)


t-SNE實踐——sklearn教程

t-SNE是一種集降維與可視化於一體的技術,它是基於SNE可視化的改進,解決了SNE在可視化后樣本分布擁擠、邊界不明顯的特點,是目前最好的降維可視化手段。 
關於t-SNE的歷史和原理詳見從SNE到t-SNE再到LargeVis。 

這里寫圖片描述

 

代碼見下面例一

TSNE的參數

函數參數表:

parameters 描述
n_components 嵌入空間的維度
perpexity 混亂度,表示t-SNE優化過程中考慮鄰近點的多少,默認為30,建議取值在5到50之間
early_exaggeration 表示嵌入空間簇間距的大小,默認為12,該值越大,可視化后的簇間距越大
learning_rate 學習率,表示梯度下降的快慢,默認為200,建議取值在10到1000之間
n_iter 迭代次數,默認為1000,自定義設置時應保證大於250
min_grad_norm 如果梯度小於該值,則停止優化。默認為1e-7
metric 表示向量間距離度量的方式,默認是歐氏距離。如果是precomputed,則輸入X是計算好的距離矩陣。也可以是自定義的距離度量函數。
init 初始化,默認為random。取值為random為隨機初始化,取值為pca為利用PCA進行初始化(常用),取值為numpy數組時必須shape=(n_samples, n_components)
verbose 是否打印優化信息,取值0或1,默認為0=>不打印信息。打印的信息為:近鄰點數量、耗時、σσ、KL散度、誤差等
random_state 隨機數種子,整數或RandomState對象
method 兩種優化方法:barnets_hutexact。第一種耗時O(NlogN),第二種耗時O(N^2)但是誤差小,同時第二種方法不能用於百萬級樣本
angle 當method=barnets_hut時,該參數有用,用於均衡效率與誤差,默認值為0.5,該值越大,效率越高&誤差越大,否則反之。當該值在0.2-0.8之間時,無變化。

返回對象的屬性表:

Atrtributes 描述
embedding_ 嵌入后的向量
kl_divergence_ KL散度
n_iter_ 迭代的輪數

t-distributed Stochastic Neighbor Embedding(t-SNE)

t-SNE可降樣本點間的相似度關系轉化為概率:在原空間(高維空間)中轉化為基於高斯分布的概率;在嵌入空間(二維空間)中轉化為基於t分布的概率。這使得t-SNE不僅可以關注局部(SNE只關注相鄰點之間的相似度映射而忽略了全局之間的相似度映射,使得可視化后的邊界不明顯),還關注全局,使可視化效果更好(簇內不會過於集中,簇間邊界明顯)。

目標函數:原空間與嵌入空間樣本分布之間的KL散度。 
優化算法:梯度下降。 
注意問題:KL散度作目標函數是非凸的,故可能需要多次初始化以防止陷入局部次優解。

t-SNE的缺點:

  • 計算量大,耗時間是PCA的百倍,內存占用大。
  • 專用於可視化,即嵌入空間只能是2維或3維。
  • 需要嘗試不同的初始化點,以防止局部次優解的影響。

t-SNE的優化

在優化t-SNE方面,有很多技巧。下面5個參數會影響t-SNE的可視化效果:

  • perplexity 混亂度。混亂度越高,t-SNE將考慮越多的鄰近點,更關注全局。因此,對於大數據應該使用較高混亂度,較高混亂度也可以幫助t-SNE拜托噪聲的影響。相對而言,該參數對可視化效果影響不大。
  • early exaggeration factor 該值表示你期望的簇間距大小,如果太大的話(大於實際簇的間距),將導致目標函數無法收斂。相對而言,該參數對可視化效果影響較小,默認就行。
  • learning rate 學習率。關鍵參數,根據具體問題調節。
  • maximum number of iterations 迭代次數。迭代次數不能太低,建議1000以上。
  • angle (not used in exact method) 角度。相對而言,該參數對效果影響不大。

PS:一個形象展示t-SNE優化技巧的網站How to Use t-SNE Effectively. 

代碼

例一

  1.  
    import numpy as np
  2.  
    import matplotlib.pyplot as plt
  3.  
    from sklearn import manifold, datasets
  4.  
     
  5.  
    digits = datasets.load_digits(n_class= 6)
  6.  
    X, y = digits.data, digits.target
  7.  
    n_samples, n_features = X.shape
  8.  
     
  9.  
    '''顯示原始數據'''
  10.  
    n = 20 # 每行20個數字,每列20個數字
  11.  
    img = np.zeros(( 10 * n, 10 * n))
  12.  
    for i in range(n):
  13.  
    ix = 10 * i + 1
  14.  
    for j in range(n):
  15.  
    iy = 10 * j + 1
  16.  
    img[ix:ix + 8, iy:iy + 8] = X[i * n + j].reshape((8, 8))
  17.  
    plt.figure(figsize=( 8, 8))
  18.  
    plt.imshow(img, cmap=plt.cm.binary)
  19.  
    plt.xticks([])
  20.  
    plt.yticks([])
  21.  
    plt.show()

 

數字

 

  1.  
    '''t-SNE'''
  2.  
    tsne = manifold.TSNE(n_components=2, init= 'pca', random_state=501)
  3.  
    X_tsne = tsne.fit_transform(X)
  4.  
     
  5.  
    print("Org data dimension is {}.
  6.  
    Embedded data dimension is {}".format(X.shape[-1], X_tsne.shape[-1]))
  7.  
     
  8.  
    '''嵌入空間可視化'''
  9.  
    x_min, x_max = X_tsne.min(0), X_tsne.max(0)
  10.  
    X_norm = (X_tsne - x_min) / (x_max - x_min) # 歸一化
  11.  
    plt.figure(figsize=(8, 8))
  12.  
    for i in range(X_norm.shape[0]):
  13.  
    plt.text(X_norm[i, 0], X_norm[i, 1], str(y[i]), color=plt.cm.Set1(y[i]),
  14.  
    fontdict={ 'weight': 'bold', 'size': 9})
  15.  
    plt.xticks([])
  16.  
    plt.yticks([])
  17.  
    plt.show()

 

可視化結果




t-SNE高維數據可視化(python)

t-SNE(t-distributedstochastic neighbor embedding )是目前最為流行的一種高維數據降維的算法。在大數據的時代,數據不僅越來越大,而且也變得越來越復雜,數據維度的轉化也在驚人的增加,例如,一組圖像的維度就是該圖像的像素個數,其范圍從數千到數百萬。

對計算機而言,處理高維數據絕對沒問題,但是人類能感知的確只有三個維度,因此很有必要將高維數據可視化的展現出來。那么如何將數據集從一個任意維度的降維到二維或三維呢。T-SNE就是一種數據降維的算法,其成立的前提是基於這樣的假設:盡管現實世界中的許多數據集是嵌入在高維空間中,但是都具有很低的內在維度。也就是說高維數據經過降維后,在低維狀態下更能顯示出其本質特性。這就是流行學習的基本思想,也稱為非線性降維。

關於t-SNE的詳細介紹可以參考:https://www.oreilly.com/learning/an-illustrated-introduction-to-the-t-sne-algorithm

下面就展示一下如何使用t-SNE算法可視化sklearn庫中的手寫字體數據集。

 

  1.  
    import numpy as np
  2.  
    import sklearn
  3.  
    from sklearn.manifold import TSNE
  4.  
    from sklearn.datasets import load_digits
  5.  
     
  6.  
    # Random state.
  7.  
    RS = 20150101
  8.  
     
  9.  
    import matplotlib.pyplot as plt
  10.  
    import matplotlib.patheffects as PathEffects
  11.  
    import matplotlib
  12.  
     
  13.  
    # We import seaborn to make nice plots.
  14.  
    import seaborn as sns
  15.  
    sns.set_style('darkgrid')
  16.  
    sns.set_palette('muted')
  17.  
    sns.set_context("notebook", font_scale=1.5,
  18.  
    rc={"lines.linewidth": 2.5})
  19.  
    digits = load_digits()
  20.  
    # We first reorder the data points according to the handwritten numbers.
  21.  
    X = np.vstack([digits.data[digits.target==i]
  22.  
    for i in range(10)])
  23.  
    y = np.hstack([digits.target[digits.target==i]
  24.  
    for i in range(10)])
  25.  
    digits_proj = TSNE(random_state=RS).fit_transform(X)
  26.  
     
  27.  
    def scatter(x, colors):
  28.  
    # We choose a color palette with seaborn.
  29.  
    palette = np.array(sns.color_palette("hls", 10))
  30.  
     
  31.  
    # We create a scatter plot.
  32.  
    f = plt.figure(figsize=(8, 8))
  33.  
    ax = plt.subplot(aspect='equal')
  34.  
    sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40,
  35.  
    c=palette[colors.astype(np.int)])
  36.  
    plt.xlim(-25, 25)
  37.  
    plt.ylim(-25, 25)
  38.  
    ax.axis('off')
  39.  
    ax.axis('tight')
  40.  
     
  41.  
    # We add the labels for each digit.
  42.  
    txts = []
  43.  
    for i in range(10):
  44.  
    # Position of each label.
  45.  
    xtext, ytext = np.median(x[colors == i, :], axis=0)
  46.  
    txt = ax.text(xtext, ytext, str(i), fontsize=24)
  47.  
    txt.set_path_effects([
  48.  
    PathEffects.Stroke(linewidth=5, foreground="w"),
  49.  
    PathEffects.Normal()])
  50.  
    txts.append(txt)
  51.  
     
  52.  
    return f, ax, sc, txts
  53.  
     
  54.  
    scatter(digits_proj, y)
  55.  
    plt.savefig('digits_tsne-generated.png', dpi=120)
  56.  
    plt.show()

 

 

可視化結果如下:

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM