python3實現Kmeans++算法


零:環境

python 3.6.5

JetBrains PyCharm 2018.1.4 x64

一:KMeans算法大致思路

  KMeans算法是機器學習中的一種無監督聚類算法,是針對不具有類型的數據進行分類的一種算法

  形象的來說可以說成是給定一組點data,給定要分類的簇數k,來求中心點和對應的簇的集合

  中心點所在的簇中的其他點都是距離該中心點最近的點,因而才在一個簇里

  具體步驟

  1、首先在點集中隨機尋找k個點來當作中心點

  2、然后初始化k個集合,用於存放對應的簇的對象

  3、開始KMeans算法的一輪。計算第i個點到k個中心點的距離[l1,l2,l3,……,ln],然后記錄下距離最短的中心點,並將該點加入到對應的簇集合中

  4、全部點都計算完之后開始計算每個簇內的所有點的中心點,即取各個維度上的平均值的點作為新的中心點

  5、計算所有新舊中心點的距離的平方的和,看是否為0,不為0則繼續循環或遞歸

  6、重復第3,4,5步驟,直到循環或遞歸跳出

  可以看出步驟還是非常簡單明了的

  關於第5步為什么是0,因為當簇的分類趨於穩定的時候,各個簇之間應當沒有數據的擺動。什么是數據的擺動呢?就是簇中的某個數上一次歸屬於簇A,這回歸屬於簇B,反復變化的情況即為擺動。

  對於KMeans算法來說是不存在的,因為新的中心點是簇內點集的中心點,所以當簇內穩定時新中心點也是穩定的,所以可以以0作為判斷條件

  因為KMeans++算法與KMeans算法區別非常小,所以在討論完KMeans++算法之后再一起發代碼

二:KMeans++的思路

  KMeans++算法實際就是修改了KMeans算法的第一步操作

  之所以進行這樣的優化,是為了讓隨機選取的中心點不再只是趨於局部最優解,而是讓其盡可能的趨於全局最優解。要注意“盡可能”的三個字,即使是正常的KMeans++算法也無法保證百分百全局最優,在說取值原理之后我們就能知道為什么了

  思路就是我們要盡可能的保證各個簇的中心點的距離要盡可能的遠

  當簇的中心盡可能的遠的時候就能夠盡可能的保證中心點之間不會在同一個簇內

  KMeans的迭代實際上就是簇的形狀的修改,只要初始形狀不太出格就會回歸於正確形狀

  具體步驟如下

  1、首先隨機尋找一個點作為中心點

  2、然后計算其他點到目前的全部簇中心點的距離(最開始只有一個中心點)

  3、計算出映射到對應點的概率

\[\frac{{D{{(k)}^2}}}{{\sum\limits_{i = 0}^{\rm{m}} {D{{(i)}^2}} }}\]

  其中D(k)就是第k個點到其他中心點的最短距離,注意還有平方

  4、根據這個概率來利用輪盤法隨機出一個中心點作為下一個中心點,然后重復2,3,4步驟直至找到全部中心點

  我們可以看出即使是KMeans++算法也只是概率性的選擇,所以還是不穩定的,但是實際效果上已經比原有的隨機選取K值好多了,當然最好的還是人工根據數據手動選取中心點

  以下是參考代碼

  1 import csv
  2 import math
  3 import random
  4 from functools import reduce
  5 import matplotlib.pyplot as plt
  6 import numpy
  7 
  8 #   KMeans++算法,優化后的KMeans的算法
  9 class KMeansPP():
 10     def __init__(self,pBasePoints,pN = 5,pPointsCSVName = "kmeans_points.csv",pSetsCSVName = "kmeas_sets.csv"):
 11         """
 12         初始化KMeans++算法的構造函數
 13         :param pBasePoints: 所要計算的數據,為點的二維數組
 14         :param pN: 要分成的簇的個數
 15         :param pPointsCSVName: 要寫入的點集的CSV文件
 16         :param pSetsCSVName: 要寫入的簇的CSV文件
 17         """
 18         self.__N = pN
 19         self.__PCSVName = pPointsCSVName
 20         self.__SCSVName = pSetsCSVName
 21         self.__M = len(pBasePoints)#數據的個數
 22         self.__basePoints = pBasePoints
 23 
 24         self.__initBaseCenterPoint()   #kmeans++算法初始化中心點
 25         #self.__centerPoints = random.sample(self.__basePoints,self.__N) #kmeans算法初始化中心點
 26         self.__initSetsAndNewCenter()#初始化簇集合
 27         pass
 28 
 29     #   初始化N個點
 30     #   這里改進為Kmeans++算法
 31     def __initBaseCenterPoint(self):
 32         self.__centerPoints = []
 33         self.__centerPoints.append(self.__basePoints[random.randint(0, self.__M - 1)])#   首先初始化一個中心點
 34         while len(self.__centerPoints) < self.__N:#添加中心點直到N個
 35             tempDX = [min([KMeansPP.f_dAB(a,b) for b in self.__centerPoints])**2 for a in self.__basePoints]#D(x)的平方的列表。這一步中的a是遍歷了所有的點,然后將a再分別與中心點集合進行遍歷求出兩點距離求出最短距離
 36             DXSum = sum(tempDX)#kmeans++公式中的分母
 37             DXP = []#輪盤法的值域范圍計算,從開始的0到最后的1
 38             for i in range(len(tempDX)):
 39                 if i == 0:
 40                     DXP.append(tempDX[0]/DXSum)
 41                 else:
 42                     DXP.append(DXP[i-1]+tempDX[i]/DXSum)
 43             #   因為中心點到其他中心點的最短距離必定是0,所以必定不會選中中心點
 44             self.__centerPoints.append(self.__basePoints[KMeansPP.f_Roulette(DXP)])
 45         pass
 46 
 47     #   初始化新中心點和中心點集合
 48     def __initSetsAndNewCenter(self):
 49         self.__sets = {k:[] for k in self.__centerPoints}
 50         self.__newCenterPoints = []
 51 
 52     #   計算新的中心點
 53     def __countNewCenterPoints(self):
 54         self.__newCenterPoints = []
 55         pDim = len(self.__basePoints[0])
 56         for i in range(self.__N):#重新計算每個簇的中心點
 57             tp = self.__sets[self.__centerPoints[i]]#獲取簇集合
 58             point = tuple([sum([p[i] for p in tp])/len(tp) for i in range(pDim)])#計算新的點。先i遍歷維度,然后遍歷每個點,對每個點的維度i取出來作為集合再求平均值。實際上就是矩陣的轉置
 59             self.__newCenterPoints.append(point)
 60         pass
 61 
 62     #   求AB距離
 63     @staticmethod
 64     def f_dAB(A,B):
 65         dim = min(len(A),len(B))
 66         return sum([(A[i] - B[i]) ** 2 for i in range(dim)]) ** 0.5
 67 
 68     #   輪盤法,返回下標
 69     @staticmethod
 70     def f_Roulette(_list):
 71         tr = random.random()
 72         for i in range(len(_list)):
 73             if i == 0 and _list[i] > tr:
 74                 return 0
 75             else:
 76                 if _list[i] > tr and _list[i - 1] <= tr:
 77                     return i
 78 
 79     #   划分集合,kmeans算法
 80     def __kmeans(self):
 81 
 82         #   {其他點:[這個點到N個中心點的距離],……}
 83         t_dList = {b:[KMeansPP.f_dAB(a, b) for a in self.__centerPoints] for b in self.__basePoints}#先遍歷b為其他點,a為中心點。計算點b到其他所有的中心點的距離
 84         for k,v in t_dList.items():
 85             self.__sets[self.__centerPoints[v.index(min(v))]].append(k)#將距離最小的添加到對應的簇里
 86 
 87         self.__countNewCenterPoints()#計算新中心點
 88         #   當各個簇之間有點變動時,就繼續
 89         if sum([KMeansPP.f_dAB(self.__centerPoints[i],self.__newCenterPoints[i]) for i in range(self.__N)]) > 0:
 90             self.__centerPoints = self.__newCenterPoints[:]#把新中心點作為中心點
 91             self.__initSetsAndNewCenter()#重置集合和新中心點
 92             self.k_means()#遞歸調用
 93         pass
 94 
 95     #   k_means算法的對外接口
 96     def k_means(self):
 97         self.__kmeans()
 98         return self.__sets,self.__centerPoints
 99 
100     def writeToCSV(self):
101         with open(self.__SCSVName,"w",newline="") as fpc:
102             fpcWriter = csv.writer(fpc)
103             fpcWriter.writerow(self.__centerPoints)
104             maxIndex = max([len(v) for k, v in self.__sets.items()])
105             fpcWriter.writerows([[v[i] if len(v) > i else "" for (k, v) in self.__sets.items()] for i in range(maxIndex)])
106             pass
107 
108         with open(self.__PCSVName,"w",newline="") as fpp:
109             fppWriter = csv.writer(fpp)
110             fppWriter.writerows([[self.__basePoints[i*10 + j] if i*10+j < self.__M else "" for j in range(10)] for i in range(self.__M//10)])
111             pass
112         pass
kmeans與kmeans++代碼

 


 

  本文原創,轉載請注明出處https://www.cnblogs.com/dofstar/p/11341494.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM