手寫k-means算法


作為聚類的代表算法,k-means本屬於NP難問題,通過迭代優化的方式,可以求解出近似解。

偽代碼如下:

 

 1,算法部分

距離采用歐氏距離。參數默認值隨意選的。

import numpy as np
def k_means(x,k=4,epochs=500,delta=1e-3):
#     隨機選取k個樣本點作為中心
    indices=np.random.randint(0,len(x),size=k)
    centers=x[indices]
#     保存分類結果
    results=[]
    for i in range(k):
        results.append([])
    step=1
    flag=True
    while flag:
        if step>epochs:
            return centers,results
        else:
#             合適的位置清空
            for i in range(k):
                results[i]=[]
#         將所有樣本划分到離它最近的中心簇
        for i in range(len(x)):
            current=x[i]
            min_dis=np.inf
            tmp=0
            for j in range(k):
                distance=dis(current,centers[j])
                if distance<min_dis:
                    min_dis=distance
                    tmp=j
            results[tmp].append(current)
#     更新中心
        for i in range(k):
            old_center=centers[i]
            new_center=np.array(results[i]).mean(axis=0)
#             如果新,舊中心不等,更新
#             if not (old_center==new_center).all():
            if dis(old_center,new_center)>delta:
                centers[i]=new_center
                flag=False
        if flag:
            break
        # 需要更新flag重設為True
        else:
            flag=True
        step+=1
    return centers,results
                
def dis(x,y):
    return np.sqrt(np.sum(np.power(x-y,2)))

2,驗證

我隨機出了一些平面上的點,然后對其分類。

x=np.random.randint(0,50,size=100)
y=np.random.randint(0,50,size=100)
z=np.array(list(zip(x,y)))

import matplotlib.pyplot as plt
%matplotlib inline

plt.plot(x,y,'ro')

首先看看未分類之前的,當然也是跟分類后的分布是一樣的。

然后看看分類后的結果:

centers,results=k_means(z)

color=['ko','go','bo','yo']
for i in range(len(results)):
    result=results[i]
    plt.plot([res[0] for res in result],[res[1] for res in result],color[i])
plt.plot([res[0] for res in centers],[res[1] for res in centers],'ro')
plt.show()

 可以看出,4個分類還是挺合理的。

再增加k=5試試,多執行幾次看看。

centers,results=k_means(z,k=5)

color=['ko','go','bo','yo','co']
for i in range(len(results)):
    result=results[i]
    plt.plot([res[0] for res in result],[res[1] for res in result],color[i])
plt.plot([res[0] for res in centers],[res[1] for res in centers],'ro')
plt.show()

 

 

 

 可以看出,此算法對初值很敏感。

   _^v^_

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM