作為聚類的代表算法,k-means本屬於NP難問題,通過迭代優化的方式,可以求解出近似解。
偽代碼如下:
1,算法部分
距離采用歐氏距離。參數默認值隨意選的。
import numpy as np def k_means(x,k=4,epochs=500,delta=1e-3): # 隨機選取k個樣本點作為中心 indices=np.random.randint(0,len(x),size=k) centers=x[indices] # 保存分類結果 results=[] for i in range(k): results.append([]) step=1 flag=True while flag: if step>epochs: return centers,results else: # 合適的位置清空 for i in range(k): results[i]=[] # 將所有樣本划分到離它最近的中心簇 for i in range(len(x)): current=x[i] min_dis=np.inf tmp=0 for j in range(k): distance=dis(current,centers[j]) if distance<min_dis: min_dis=distance tmp=j results[tmp].append(current) # 更新中心 for i in range(k): old_center=centers[i] new_center=np.array(results[i]).mean(axis=0) # 如果新,舊中心不等,更新 # if not (old_center==new_center).all(): if dis(old_center,new_center)>delta: centers[i]=new_center flag=False if flag: break # 需要更新flag重設為True else: flag=True step+=1 return centers,results def dis(x,y): return np.sqrt(np.sum(np.power(x-y,2)))
2,驗證
我隨機出了一些平面上的點,然后對其分類。
x=np.random.randint(0,50,size=100) y=np.random.randint(0,50,size=100) z=np.array(list(zip(x,y))) import matplotlib.pyplot as plt %matplotlib inline plt.plot(x,y,'ro')
首先看看未分類之前的,當然也是跟分類后的分布是一樣的。
然后看看分類后的結果:
centers,results=k_means(z) color=['ko','go','bo','yo'] for i in range(len(results)): result=results[i] plt.plot([res[0] for res in result],[res[1] for res in result],color[i]) plt.plot([res[0] for res in centers],[res[1] for res in centers],'ro') plt.show()
可以看出,4個分類還是挺合理的。
再增加k=5試試,多執行幾次看看。
centers,results=k_means(z,k=5) color=['ko','go','bo','yo','co'] for i in range(len(results)): result=results[i] plt.plot([res[0] for res in result],[res[1] for res in result],color[i]) plt.plot([res[0] for res in centers],[res[1] for res in centers],'ro') plt.show()
可以看出,此算法對初值很敏感。
_^v^_