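A while ago a junior schoolmate asked me how the initial center points for k-means should be chosen, and whether there was an algorithm for it. I pointed him to k-means++, but he said there were parts he didn't understand, so I annotated those parts for him.
Below is material I found online, with my annotations on the underlined parts.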
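The basic idea behind how k-means++ chooses its initial seeds is that the initial cluster centers should be as far apart from one another as possible. Wikipedia describes the algorithm as follows: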
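1. Choose one point at random from the input data set as the first cluster center.
2. For each point x in the data set, compute D(x), its distance to the nearest cluster center (among the centers already chosen).
3. Choose a new data point as the next cluster center, such that points with a larger D(x) are more likely to be chosen (in Wikipedia's formulation, with probability proportional to D(x)²).
4. Repeat steps 2 and 3 until k cluster centers have been chosen.
5. Run the standard k-means algorithm using these k initial cluster centers.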
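The seeding steps above can be sketched in Python roughly as follows. This is a minimal sketch, assuming points are 2-D `(x, y)` tuples; the function name `kmeanspp_seeds` is hypothetical. Following Wikipedia's formulation, it weights each point by D(x)² and uses the standard-library `random.choices` (Python 3.6+) for the weighted draw. Step 5, the k-means run itself, is not included.

```python
import random

def kmeanspp_seeds(points, k, rng=random):
    """Choose up to k initial centers from a list of (x, y) tuples, k-means++ style."""
    if not 1 <= k <= len(points):
        raise ValueError("k must be between 1 and len(points)")

    def sq_dist(a, b):
        return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2

    # Step 1: the first center is a uniformly random data point.
    centers = [rng.choice(points)]
    # Step 4: repeat steps 2 and 3 until k centers have been chosen.
    while len(centers) < k:
        # Step 2: weight of x is D(x)^2, the squared distance to its nearest chosen center.
        weights = [min(sq_dist(p, c) for c in centers) for p in points]
        if sum(weights) == 0:
            break  # every point coincides with a center; no new distinct seed exists
        # Step 3: draw the next center with probability proportional to D(x)^2.
        centers.append(rng.choices(points, weights=weights, k=1)[0])
    return centers


rng = random.Random(0)  # seeded only to make the run repeatable
pts = [(0, 0), (0, 1), (10, 10), (10, 11), (-10, 5)]
print(kmeanspp_seeds(pts, 3, rng))
```

Points that are already centers have weight 0, so with distinct input points the draw never returns a duplicate center.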
As the description shows, the key to the algorithm is step 3: how to turn D(x) into the probability that a point is selected. One way to do it is:
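1. Pick a random point from the data set as the first "seed point".
2. For each point, compute its distance D(x) to the nearest seed point, store these distances in an array, and add them up to get Sum(D(x)).
3. Then draw a random value and use the distances as weights to choose the next seed point. Concretely, take a random value Random that falls within Sum(D(x)), then walk through the points doing Random -= D(x) until Random <= 0; the point at which this happens is the next seed point.
   - Random can be drawn as Random = Sum(D(x)) multiplied by a random number between 0 and 1.
   - The reason for taking a value that falls within Sum(D(x)) is that Random is uniformly random, so it is more likely to land in a region where D(x) is large. In the figure below, Random is more likely to land in D(x3).
   - The purpose of Random -= D(x) is to find out which interval Random actually landed in.

As the figure above shows, suppose Random falls in the interval D(x3). Then, after "repeatedly doing Random -= D(x) until it is <= 0", the point reached is x3, and x3 becomes the center chosen in this step.
4. Repeat steps 2 and 3 until k cluster centers have been chosen.
5. Run the standard k-means algorithm using these k initial cluster centers.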
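The "Random -= D(x)" selection described above can be sketched as below. This is a minimal sketch: the distances `[1.0, 2.0, 5.0, 2.0]` for x1..x4 are made-up values standing in for the figure, and `pick_by_subtraction` and `pick_by_bisect` are hypothetical helper names. The first function follows the text literally; the second is an equivalent alternative (not from the original) that finds the same interval by binary-searching the running totals with the standard `bisect` module.

```python
import bisect
import itertools

def pick_by_subtraction(d, r):
    """Index chosen by 'Random -= D(x) until Random <= 0', where Random = Sum(D(x)) * r."""
    rand = sum(d) * r  # r is a random number in [0, 1)
    for i, di in enumerate(d):
        rand -= di
        if rand <= 0:
            return i
    return len(d) - 1  # guard against floating-point round-off

def pick_by_bisect(d, r):
    """Same interval lookup: first index whose running total reaches Sum(D(x)) * r."""
    cum = list(itertools.accumulate(d))
    return min(bisect.bisect_left(cum, cum[-1] * r), len(d) - 1)


d = [1.0, 2.0, 5.0, 2.0]  # hypothetical D(x1)..D(x4)
print(pick_by_subtraction(d, 0.45))  # Random = 4.5 -> 3.5 -> 1.5 -> -3.5, prints 2 (x3)
print(pick_by_bisect(d, 0.45))       # running totals [1, 3, 8, 10]; 4.5 falls in (3, 8], prints 2
```

Because x3's interval covers 5 of the 10 units, a uniformly random r selects x3 with probability 1/2. The mechanism works for any nonnegative weights, including the squared distances D(x)² used by k-means++. As written both functions are O(n) per draw; the bisect version only becomes O(log n) per draw if the running totals are computed once and reused.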
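Honestly, this algorithm is clearest when read alongside the code. Below is a Python implementation of k-means++ seeding followed by standard k-means (Lloyd's algorithm). Note that `nearest_cluster_center` returns the squared distance, so `kpp` weights the points by D(x)², as the original k-means++ does.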
```python
from math import pi, sin, cos
from collections import namedtuple
from random import random, choice
from copy import copy

FLOAT_MAX = 1e100


class Point:
    __slots__ = ["x", "y", "group"]

    def __init__(self, x=0.0, y=0.0, group=0):
        self.x, self.y, self.group = x, y, group


def generate_points(npoints, radius):
    points = [Point() for _ in range(npoints)]
    # note: this is not a uniform 2-d distribution
    for p in points:
        r = random() * radius
        ang = random() * 2 * pi
        p.x = r * cos(ang)
        p.y = r * sin(ang)
    return points


def nearest_cluster_center(point, cluster_centers):
    """Index of the closest cluster center and the *squared* distance to it."""
    def sqr_distance_2D(a, b):
        return (a.x - b.x) ** 2 + (a.y - b.y) ** 2

    min_index = point.group
    min_dist = FLOAT_MAX
    for i, cc in enumerate(cluster_centers):
        d = sqr_distance_2D(cc, point)
        if min_dist > d:
            min_dist = d
            min_index = i
    return (min_index, min_dist)


def kpp(points, cluster_centers):
    """k-means++ seeding: fill cluster_centers in place, then assign each point to its nearest center."""
    # Step 1: the first center is chosen uniformly at random
    cluster_centers[0] = copy(choice(points))
    d = [0.0 for _ in range(len(points))]
    for i in range(1, len(cluster_centers)):
        # Step 2: d[j] = squared distance from point j to its nearest chosen center
        total = 0.0
        for j, p in enumerate(points):
            d[j] = nearest_cluster_center(p, cluster_centers[:i])[1]
            total += d[j]
        # Step 3: Random = Sum(D(x)) * r, then subtract d[j] until it drops to <= 0
        total *= random()
        for j, di in enumerate(d):
            total -= di
            if total > 0:
                continue
            cluster_centers[i] = copy(points[j])
            break
        else:
            # rare: floating-point round-off left total slightly positive
            cluster_centers[i] = copy(points[-1])
    for p in points:
        p.group = nearest_cluster_center(p, cluster_centers)[0]


def lloyd(points, nclusters):
    cluster_centers = [Point() for _ in range(nclusters)]
    # call k++ init
    kpp(points, cluster_centers)
    lenpts10 = len(points) >> 10  # about 0.1% of the points
    while True:
        # group element for centroids are used as counters
        for cc in cluster_centers:
            cc.x = 0
            cc.y = 0
            cc.group = 0
        for p in points:
            cluster_centers[p.group].group += 1
            cluster_centers[p.group].x += p.x
            cluster_centers[p.group].y += p.y
        for cc in cluster_centers:
            # skip empty clusters to avoid ZeroDivisionError (their center stays at the origin)
            if cc.group:
                cc.x /= cc.group
                cc.y /= cc.group
        # find the closest centroid of each point
        changed = 0
        for p in points:
            min_i = nearest_cluster_center(p, cluster_centers)[0]
            if min_i != p.group:
                changed += 1
                p.group = min_i
        # stop when 99.9% of points are good
        if changed <= lenpts10:
            break
    for i, cc in enumerate(cluster_centers):
        cc.group = i
    return cluster_centers


def print_eps(points, cluster_centers, W=400, H=400):
    Color = namedtuple("Color", "r g b")
    colors = []
    for i in range(len(cluster_centers)):
        colors.append(Color((3 * (i + 1) % 11) / 11.0,
                            (7 * i % 11) / 11.0,
                            (9 * i % 11) / 11.0))
    max_x = max(p.x for p in points)
    min_x = min(p.x for p in points)
    max_y = max(p.y for p in points)
    min_y = min(p.y for p in points)
    scale = min(W / (max_x - min_x),
                H / (max_y - min_y))
    cx = (max_x + min_x) / 2
    cy = (max_y + min_y) / 2
    print("%%!PS-Adobe-3.0\n%%%%BoundingBox: -5 -5 %d %d" % (W + 10, H + 10))
    print("/l {rlineto} def /m {rmoveto} def\n" +
          "/c { .25 sub exch .25 sub exch .5 0 360 arc fill } def\n" +
          "/s { moveto -2 0 m 2 2 l 2 -2 l -2 -2 l closepath " +
          " gsave 1 setgray fill grestore gsave 3 setlinewidth" +
          " 1 setgray stroke grestore 0 setgray stroke }def")
    for i, cc in enumerate(cluster_centers):
        print("%g %g %g setrgbcolor" %
              (colors[i].r, colors[i].g, colors[i].b))
        for p in points:
            if p.group != i:
                continue
            print("%.3f %.3f c" % ((p.x - cx) * scale + W / 2,
                                   (p.y - cy) * scale + H / 2))
        print("\n0 setgray %g %g s" % ((cc.x - cx) * scale + W / 2,
                                       (cc.y - cy) * scale + H / 2))
    # no % formatting here, so write the DSC comment literally
    print("\n%%EOF")


def main():
    npoints = 30000
    k = 7  # number of clusters
    points = generate_points(npoints, 10)
    cluster_centers = lloyd(points, k)
    print_eps(points, cluster_centers)


if __name__ == "__main__":
    main()
```
