K-Means K均值聚類 python代碼實現


本代碼參考自: https://github.com/lawlite19/MachineLearning_Python/blob/master/K-Means/K-Menas.py 

1.  初始化類中心,從樣本中隨機選取K個點作為初始的聚類中心點

def kMeansInitCentroids(X,K):
    m = X.shape[0]
    m_arr = np.arange(0,m)      # 生成0-m-1
    centroids = np.zeros((K,X.shape[1]))
    np.random.shuffle(m_arr)    # 打亂m_arr順序    
    rand_indices = m_arr[:K]    # 取前K個
    centroids = X[rand_indices,:]
    return centroids

2.  找出每個樣本離哪一個類中心的距離最近,並返回

def findClosestCentroids(x,inital_centroids):
    m = x.shape[0]   #樣本的個數 
    k = inital_centroids.shape[0]  #類別的數目
    dis = np.zeros((m,k))   # 存儲每個點到k個類的距離
    idx = np.zeros((m,1))   # 要返回的每條數據屬於哪個類別 
    
    """計算每個點到每個類的中心的距離"""
    for i in range(m):
        for j in range(k):
            dis[i,j] = np.dot((x[i,:] - inital_centroids[j,:]).reshape(1,-1),
                              (x[i,:] - inital_centroids[j,:]).reshape(-1,1))
    '''返回dis每一行的最小值對應的列號,即為對應的類別
    - np.min(dis, axis=1)  返回每一行的最小值
    - np.where(dis == np.min(dis, axis=1).reshape(-1,1)) 返回對應最小值的坐標
     - 注意:可能最小值對應的坐標有多個,where都會找出來,所以返回時返回前m個需要的即可(因為對於多個最小值,
     屬於哪個類別都可以)
    ''' 
    dummy,idx = np.where(dis == np.min(dis,axis=1).reshape(-1,1))
    return idx[0:dis.shape[0]]   

3. 更新類中心

def computerCentroids(x,idx,k):
    n = x.shape[1]   #每個樣本的維度
    centroids = np.zeros((k,n))   #定義每個中心點的形狀,其中維度和每個樣本的維度一樣
    for i in range(k):
        # 索引要是一維的, axis=0為每一列,idx==i一次找出屬於哪一類的,然后計算均值
        centroids[i,:] = np.mean(x[np.ravel(idx==i),:],axis=0).reshape(1,-1)
    return centroids

4. K-Means算法實現

def runKMeans(x,initial_centroids,max_iters,plot_process):
    m,n = x.shape    #樣本的個數和維度
    k = initial_centroids.shape[0]   #聚類的類數
    centroids = initial_centroids   #記錄當前類別的中心
    previous_centroids = centroids   #記錄上一次類別的中心
    idx = np.zeros((m,1))    #每條數據屬於哪個類
    
    for i in range(max_iters): 
        print("迭代計算次數:%d"%(i+1))
        idx = findClosestCentroids(x,centroids)
        if plot_process:    # 如果繪制圖像
            plt = plotProcessKMeans(X,centroids,previous_centroids,idx) # 畫聚類中心的移動過程
            previous_centroids = centroids  # 重置 
            plt.show()
        centroids = computerCentroids(x,idx,k)   #重新計算類中心
    return centroids,idx   #返回聚類中心和數據屬於哪個類別 

5. 繪制聚類中心的移動過程

def plotProcessKMeans(X,centroids,previous_centroids,idx):
    for i in range(len(idx)):
        if idx[i] == 0:
            plt.scatter(X[i,0], X[i,1],c="r")     # 原數據的散點圖 二維形式 
        elif idx[i] == 1:
            plt.scatter(X[i,0],X[i,1],c="b")
        else:
            plt.scatter(X[i,0],X[i,1],c="g")
    plt.plot(previous_centroids[:,0],previous_centroids[:,1],'rx',markersize=10,linewidth=5.0)  # 上一次聚類中心
    plt.plot(centroids[:,0],centroids[:,1],'rx',markersize=10,linewidth=5.0)                    # 當前聚類中心
    for j in range(centroids.shape[0]): # 遍歷每個類,畫類中心的移動直線
        p1 = centroids[j,:]
        p2 = previous_centroids[j,:]
        plt.plot([p1[0],p2[0]],[p1[1],p2[1]],"->",linewidth=2.0)
    return plt

6. 主程序實現

if __name__ == "__main__":
    print("聚類過程展示....\n")
    data = spio.loadmat("./data/data.mat")
    X = data['X']
    K = 3 
    initial_centroids = kMeansInitCentroids(X,K)
    max_iters = 10 
    runKMeans(X,initial_centroids,max_iters,True) 

7. 結果

聚類過程展示....

迭代計算次數:1

迭代計算次數:2

迭代計算次數:3

迭代計算次數:4

迭代計算次數:5

迭代計算次數:6

迭代計算次數:7

迭代計算次數:8

迭代計算次數:9

迭代計算次數:10

 
         
         
        
 
        

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM