(數據科學學習手札14)Mean-Shift聚類法簡單介紹及Python實現


不管之前介紹的K-means還是K-medoids聚類,都得事先確定聚類簇的個數,而且肘部法則也並不是萬能的,總會遇到難以抉擇的情況,而本篇將要介紹的Mean-Shift聚類法就可以自動確定k的個數,下面簡要介紹一下其算法流程:

  1.隨機確定樣本空間內一個半徑確定的高維球及其球心;

  2.求該高維球內質心,並將高維球的球心移動至該質心處;

  3.重復2,直到高維球內的密度隨着繼續的球心滑動變化低於設定的閾值,算法結束

具體的原理可以參考下面的地址,筆者讀完覺得說的比較明了易懂:

http://blog.csdn.net/google19890102/article/details/51030884

而在Python中,機器學習包sklearn中封裝有該算法,下面用一個簡單的示例來演示如何在Python中使用Mean-Shift聚類:

一、低維

from sklearn.cluster import MeanShift
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from matplotlib.pyplot import style
import numpy as np
'''設置繪圖風格'''
style.use('ggplot')
'''生成演示用樣本數據'''
data1 = np.random.normal(0,0.3,(1000,2))
data2 = np.random.normal(1,0.2,(1000,2))
data3 = np.random.normal(2,0.3,(1000,2))

data = np.concatenate((data1,data2,data3))

# data_tsne = TSNE(learning_rate=100).fit_transform(data)
'''搭建Mean-Shift聚類器'''
clf=MeanShift()
'''對樣本數據進行聚類'''
predicted=clf.fit_predict(data)
colors = [['red','green','blue','grey'][i] for i in predicted]
'''繪制聚類圖'''
plt.scatter(data[:,0],data[:,1],c=colors,s=10)
plt.title('Mean Shift')

二、高維

from sklearn.cluster import MeanShift
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from matplotlib.pyplot import style
import numpy as np
'''設置繪圖風格'''
style.use('ggplot')
'''生成演示用樣本數據'''
data1 = np.random.normal(0,0.3,(1000,6))
data2 = np.random.normal(1,0.2,(1000,6))
data3 = np.random.normal(2,0.3,(1000,6))

data = np.concatenate((data1,data2,data3))

data_tsne = TSNE(learning_rate=100).fit_transform(data)
'''搭建Mean-Shift聚類器'''
clf=MeanShift()
'''對樣本數據進行聚類'''
predicted=clf.fit_predict(data)
colors = [['red','green','blue','grey'][i] for i in predicted]
'''繪制聚類圖'''
plt.scatter(data_tsne[:,0],data_tsne[:,1],c=colors,s=10)
plt.title('Mean Shift')

 

三、實際生活中的復雜數據

我們以之前一篇關於K-means聚類的實戰中使用到的重慶美團商戶數據為例,進行Mean-Shift聚類:

 

import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift
from sklearn.manifold import TSNE
import pandas as pd
import numpy as np
from matplotlib.pyplot import style

style.use('ggplot')

data = pd.read_excel(r'C:\Users\windows\Desktop\重慶美團商家信息.xlsx')
input = pd.DataFrame({'score':data['商家評分'][data['數據所屬期'] == data.iloc[0,0]],
                      'comment':data['商家評論數'][data['數據所屬期'] == data.iloc[0,0]],
                      'sales':data['本月銷售額'][data['數據所屬期'] == data.iloc[0,0]]})

'''去缺省值'''
input = input.dropna()

input_tsne = TSNE(learning_rate=100).fit_transform(input)

'''創造色彩列表'''
with open(r'C:\Users\windows\Desktop\colors.txt','r') as cc:
    col = cc.readlines()
col = [col[i][:7] for i in range(len(col)) if col[i][0] == '#']

'''進行Mean-Shift聚類'''
clf = MeanShift()
cl = clf.fit_predict(input)

'''繪制聚類結果'''
np.random.shuffle(col)
plt.scatter(input_tsne[:,0],input_tsne[:,1],c=[col[i] for i in cl],s=8)
plt.title('Mean-Shift Cluster of {}'.format(str(len(set(cl)))))

 

 

可見在實際工作中的復雜數據用Mean-Shift來聚類因為無法控制k個值,可能會產生過多的類而導致聚類失去意義,但Mean-Shift在圖像分割上用處很大。

 

以上便是本篇對Mean-Shift簡單的介紹,如有錯誤望指出。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM