不管之前介紹的K-means還是K-medoids聚類,都得事先確定聚類簇的個數,而且肘部法則也並不是萬能的,總會遇到難以抉擇的情況,而本篇將要介紹的Mean-Shift聚類法就可以自動確定k的個數,下面簡要介紹一下其算法流程:
1.隨機確定樣本空間內一個半徑確定的高維球及其球心;
2.求該高維球內質心,並將高維球的球心移動至該質心處;
3.重復2,直到高維球內的密度隨着繼續的球心滑動變化低於設定的閾值,算法結束
具體的原理可以參考下面的地址,筆者讀完覺得說的比較明了易懂:
http://blog.csdn.net/google19890102/article/details/51030884
而在Python中,機器學習包sklearn中封裝有該算法,下面用一個簡單的示例來演示如何在Python中使用Mean-Shift聚類:
一、低維
from sklearn.cluster import MeanShift import matplotlib.pyplot as plt from sklearn.manifold import TSNE from matplotlib.pyplot import style import numpy as np '''設置繪圖風格''' style.use('ggplot') '''生成演示用樣本數據''' data1 = np.random.normal(0,0.3,(1000,2)) data2 = np.random.normal(1,0.2,(1000,2)) data3 = np.random.normal(2,0.3,(1000,2)) data = np.concatenate((data1,data2,data3)) # data_tsne = TSNE(learning_rate=100).fit_transform(data) '''搭建Mean-Shift聚類器''' clf=MeanShift() '''對樣本數據進行聚類''' predicted=clf.fit_predict(data) colors = [['red','green','blue','grey'][i] for i in predicted] '''繪制聚類圖''' plt.scatter(data[:,0],data[:,1],c=colors,s=10) plt.title('Mean Shift')
二、高維
from sklearn.cluster import MeanShift import matplotlib.pyplot as plt from sklearn.manifold import TSNE from matplotlib.pyplot import style import numpy as np '''設置繪圖風格''' style.use('ggplot') '''生成演示用樣本數據''' data1 = np.random.normal(0,0.3,(1000,6)) data2 = np.random.normal(1,0.2,(1000,6)) data3 = np.random.normal(2,0.3,(1000,6)) data = np.concatenate((data1,data2,data3)) data_tsne = TSNE(learning_rate=100).fit_transform(data) '''搭建Mean-Shift聚類器''' clf=MeanShift() '''對樣本數據進行聚類''' predicted=clf.fit_predict(data) colors = [['red','green','blue','grey'][i] for i in predicted] '''繪制聚類圖''' plt.scatter(data_tsne[:,0],data_tsne[:,1],c=colors,s=10) plt.title('Mean Shift')
三、實際生活中的復雜數據
我們以之前一篇關於K-means聚類的實戰中使用到的重慶美團商戶數據為例,進行Mean-Shift聚類:
import matplotlib.pyplot as plt from sklearn.cluster import MeanShift from sklearn.manifold import TSNE import pandas as pd import numpy as np from matplotlib.pyplot import style style.use('ggplot') data = pd.read_excel(r'C:\Users\windows\Desktop\重慶美團商家信息.xlsx') input = pd.DataFrame({'score':data['商家評分'][data['數據所屬期'] == data.iloc[0,0]], 'comment':data['商家評論數'][data['數據所屬期'] == data.iloc[0,0]], 'sales':data['本月銷售額'][data['數據所屬期'] == data.iloc[0,0]]}) '''去缺省值''' input = input.dropna() input_tsne = TSNE(learning_rate=100).fit_transform(input) '''創造色彩列表''' with open(r'C:\Users\windows\Desktop\colors.txt','r') as cc: col = cc.readlines() col = [col[i][:7] for i in range(len(col)) if col[i][0] == '#'] '''進行Mean-Shift聚類''' clf = MeanShift() cl = clf.fit_predict(input) '''繪制聚類結果''' np.random.shuffle(col) plt.scatter(input_tsne[:,0],input_tsne[:,1],c=[col[i] for i in cl],s=8) plt.title('Mean-Shift Cluster of {}'.format(str(len(set(cl)))))
可見在實際工作中的復雜數據用Mean-Shift來聚類因為無法控制k個值,可能會產生過多的類而導致聚類失去意義,但Mean-Shift在圖像分割上用處很大。
以上便是本篇對Mean-Shift簡單的介紹,如有錯誤望指出。