K-Means算法的Python實現


算法簡介

K-Means是一種常用的聚類算法。聚類在機器學習分類中屬於無監督學習,在數據集沒有標注的情況下,便於對數據進行分群。而K-Means中的K即指將數據集分成K個子集合。

K-Means演示

從以下的動畫、視頻和計算過程可以較為直觀了解算法的計算過程。

動畫展示

wikipedia kmeans animation

視頻展示

https://youtu.be/BVFG7fd1H30

在線展示

kmeans測試頁面

使用場景

由於簡單和低維度下高效的特性,K-Means算法被應用在人群分類,圖像分段,文本分類以及數據挖掘前數據預處理場景中。

算法理解

計算流程

一下使用$$分隔的內容為LaTeX編碼的數學表達式,請自行解析。
假設有n個點$$x_{1}$$, $$x_{2}$$, $$x_{3}$$, ..., $$x_{n}$$ 以及子集數量K。

  • 步驟1 取出K個隨機向量作為中心點用於初始化

\[C = c_{1},c_{2},...,c_{k} \]

  • 步驟2 計算每個點$$x_{n}$$與K個中心點的距離,然后將每個點聚集到與之最近的中心點

\[\min_{c_{i} \in C} dist(c_{i},x) \]

dist函數用於實現歐式距離計算。

  • 步驟3 新的聚集出來之后,計算每個聚集的新中心點

\[c_{i} = avg(\sum_{x_{i} \in S_{i}} x_{i})​ \]

Si表示歸屬於第i個中心點的數據。

  • 步驟4 迭代步驟2和步驟3,直至滿足退出條件(中心點不再變化)

Python代碼實現

本代碼參考了https://mubaris.com/posts/kmeans-clustering/這篇博客, 用於聚類的數據集可從GitHub上下載到,下載的地址https://github.com/mubaris/friendly-fortnight/blob/master/xclara.csv

Python代碼如下:

導包,初始化圖形參數,導入樣例數據集

%matplotlib inline
from copy import deepcopy
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = (16, 9)
plt.style.use('ggplot')


# 導入數據集
data = pd.read_csv('xclara.csv')
# print(data.shape)
# data.head()

將數據集轉換為二維數組,並繪制二維坐標圖

# 將csv文件中的數據轉換為二維數組
f1 = data['V1'].values
f2 = data['V2'].values

X = np.array(list(zip(f1, f2)))
plt.scatter(f1, f2, c='black', s=6)

樣例點

定義距離計算函數

# 按行的方式計算兩個坐標點之間的距離
def dist(a, b, ax=1):
    return np.linalg.norm(a - b, axis=ax)

初始化分區數,隨機獲得初始中心點

# 設定分區數
k = 3
# 隨機獲得中心點的X軸坐標
C_x = np.random.randint(0, np.max(X)-20, size=k)
# 隨機獲得中心點的Y軸坐標
C_y = np.random.randint(0, np.max(X)-20, size=k)
C = np.array(list(zip(C_x, C_y)), dtype=np.float32)

將初始化中心點和樣例數據畫到同一個坐標系上

# 將初始化中心點畫到輸入的樣例數據上
plt.scatter(f1, f2, c='black', s=7)
plt.scatter(C_x, C_y, marker='*', s=200, c='red')

初始節點和樣例數據節點

實現K-Means中的核心迭代

# 用於保存中心點更新前的坐標
C_old = np.zeros(C.shape)
print(C)
# 用於保存數據所屬中心點
clusters = np.zeros(len(X))
# 迭代標識位,通過計算新舊中心點的距離
iteration_flag = dist(C, C_old, 1)

tmp = 1
# 若中心點不再變化或循環次數不超過20次(此限制可取消),則退出循環
while iteration_flag.any() != 0 and tmp < 20:
    # 循環計算出每個點對應的最近中心點
    for i in range(len(X)):
        # 計算出每個點與中心點的距離
        distances = dist(X[i], C, 1)
        # print(distances)
        # 記錄0 - k-1個點中距離近的點
        cluster = np.argmin(distances) 
        # 記錄每個樣例點與哪個中心點距離最近
        clusters[i] = cluster
        
    # 采用深拷貝將當前的中心點保存下來
    # print("the distinct of clusters: ", set(clusters))
    C_old = deepcopy(C)
    # 從屬於中心點放到一個數組中,然后按照列的方向取平均值
    for i in range(k):
        points = [X[j] for j in range(len(X)) if clusters[j] == i]
        # print(points)
        # print(np.mean(points, axis=0))
        C[i] = np.mean(points, axis=0)
        # print(C[i])
    # print(C)
    
    # 計算新舊節點的距離
    print ('循環第%d次' % tmp)
    tmp = tmp + 1
    iteration_flag = dist(C, C_old, 1)
    print("新中心點與舊點的距離:", iteration_flag)

將最終結果和樣例點畫到同一個坐標系上

# 最終結果圖示
colors = ['r', 'g', 'b', 'y', 'c', 'm']
fig, ax = plt.subplots()
# 不同的子集使用不同的顏色
for i in range(k):
        points = np.array([X[j] for j in range(len(X)) if clusters[j] == i])
        ax.scatter(points[:, 0], points[:, 1], s=7, c=colors[i])
ax.scatter(C[:, 0], C[:, 1], marker='*', s=200, c='black')

最終計算結果圖示


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM