本代碼參考自: https://github.com/lawlite19/MachineLearning_Python/blob/master/K-Means/K-Menas.py
1. 初始化類中心,從樣本中隨機選取K個點作為初始的聚類中心點
def kMeansInitCentroids(X,K):
m = X.shape[0]
m_arr = np.arange(0,m) # 生成0-m-1
centroids = np.zeros((K,X.shape[1]))
np.random.shuffle(m_arr) # 打亂m_arr順序
rand_indices = m_arr[:K] # 取前K個
centroids = X[rand_indices,:]
return centroids
2. 找出每個樣本離哪一個類中心的距離最近,並返回
def findClosestCentroids(x,inital_centroids):
m = x.shape[0] #樣本的個數
k = inital_centroids.shape[0] #類別的數目
dis = np.zeros((m,k)) # 存儲每個點到k個類的距離
idx = np.zeros((m,1)) # 要返回的每條數據屬於哪個類別
"""計算每個點到每個類的中心的距離"""
for i in range(m):
for j in range(k):
dis[i,j] = np.dot((x[i,:] - inital_centroids[j,:]).reshape(1,-1),
(x[i,:] - inital_centroids[j,:]).reshape(-1,1))
'''返回dis每一行的最小值對應的列號,即為對應的類別
- np.min(dis, axis=1) 返回每一行的最小值
- np.where(dis == np.min(dis, axis=1).reshape(-1,1)) 返回對應最小值的坐標
- 注意:可能最小值對應的坐標有多個,where都會找出來,所以返回時返回前m個需要的即可(因為對於多個最小值,
屬於哪個類別都可以)
'''
dummy,idx = np.where(dis == np.min(dis,axis=1).reshape(-1,1))
return idx[0:dis.shape[0]]
3. 更新類中心
def computerCentroids(x,idx,k):
n = x.shape[1] #每個樣本的維度
centroids = np.zeros((k,n)) #定義每個中心點的形狀,其中維度和每個樣本的維度一樣
for i in range(k):
# 索引要是一維的, axis=0為每一列,idx==i一次找出屬於哪一類的,然后計算均值
centroids[i,:] = np.mean(x[np.ravel(idx==i),:],axis=0).reshape(1,-1)
return centroids
4. K-Means算法實現
def runKMeans(x,initial_centroids,max_iters,plot_process):
m,n = x.shape #樣本的個數和維度
k = initial_centroids.shape[0] #聚類的類數
centroids = initial_centroids #記錄當前類別的中心
previous_centroids = centroids #記錄上一次類別的中心
idx = np.zeros((m,1)) #每條數據屬於哪個類
for i in range(max_iters):
print("迭代計算次數:%d"%(i+1))
idx = findClosestCentroids(x,centroids)
if plot_process: # 如果繪制圖像
plt = plotProcessKMeans(X,centroids,previous_centroids,idx) # 畫聚類中心的移動過程
previous_centroids = centroids # 重置
plt.show()
centroids = computerCentroids(x,idx,k) #重新計算類中心
return centroids,idx #返回聚類中心和數據屬於哪個類別
5. 繪制聚類中心的移動過程
def plotProcessKMeans(X,centroids,previous_centroids,idx):
for i in range(len(idx)):
if idx[i] == 0:
plt.scatter(X[i,0], X[i,1],c="r") # 原數據的散點圖 二維形式
elif idx[i] == 1:
plt.scatter(X[i,0],X[i,1],c="b")
else:
plt.scatter(X[i,0],X[i,1],c="g")
plt.plot(previous_centroids[:,0],previous_centroids[:,1],'rx',markersize=10,linewidth=5.0) # 上一次聚類中心
plt.plot(centroids[:,0],centroids[:,1],'rx',markersize=10,linewidth=5.0) # 當前聚類中心
for j in range(centroids.shape[0]): # 遍歷每個類,畫類中心的移動直線
p1 = centroids[j,:]
p2 = previous_centroids[j,:]
plt.plot([p1[0],p2[0]],[p1[1],p2[1]],"->",linewidth=2.0)
return plt
6. 主程序實現
if __name__ == "__main__":
print("聚類過程展示....\n")
data = spio.loadmat("./data/data.mat")
X = data['X']
K = 3
initial_centroids = kMeansInitCentroids(X,K)
max_iters = 10
runKMeans(X,initial_centroids,max_iters,True)
7. 結果
聚類過程展示.... 迭代計算次數:1

迭代計算次數:2

迭代計算次數:3

迭代計算次數:4

迭代計算次數:5

迭代計算次數:6

迭代計算次數:7

迭代計算次數:8

迭代計算次數:9

迭代計算次數:10

