kmeans++


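A while ago a junior schoolmate asked me how to choose the initial centers for the k-means algorithm, and whether there was an algorithm for it. I pointed him to k-means++, but he said there were parts he could not follow, so I annotated the parts he did not understand.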

Below is material I found online; my annotations cover the parts that were underlined.

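      The basic idea behind how k-means++ chooses its initial seeds is that the initial cluster centers should be as far apart from one another as possible. Wikipedia describes the algorithm as follows: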

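  1. Pick one point at random from the input data points as the first cluster center.
  2. For each point x in the data set, compute D(x), the distance from x to the nearest cluster center (among the centers already chosen).
  3. Choose a new data point as the next cluster center, such that points with a larger D(x) are more likely to be chosen (in the original k-means++ formulation, with probability proportional to D(x)^2).
  4. Repeat steps 2 and 3 until k cluster centers have been chosen.
  5. Run the standard k-means algorithm with these k initial cluster centers.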
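
To make steps 2 and 3 concrete, here is a minimal sketch of how D(x) and the resulting selection probabilities could be computed. It assumes points are plain coordinate tuples and uses the D(x)^2 weighting; `nearest_distance` and `selection_probabilities` are names I made up for illustration, and `math.dist` needs Python 3.8 or later.

```python
import math

def nearest_distance(x, centers):
    """D(x): Euclidean distance from point x to the closest chosen center."""
    return min(math.dist(x, c) for c in centers)

def selection_probabilities(points, centers):
    """Chance of each point becoming the next center, proportional to D(x)**2.

    Assumes at least one point is not already a center (otherwise total is 0).
    """
    weights = [nearest_distance(p, centers) ** 2 for p in points]
    total = sum(weights)
    return [w / total for w in weights]

points = [(0.0, 0.0), (1.0, 0.0), (3.0, 0.0)]
centers = [(0.0, 0.0)]             # pretend step 1 picked the first point
probs = selection_probabilities(points, centers)
print(probs)                       # [0.0, 0.1, 0.9]
```

The point that is already a center has D(x) = 0 and so can never be picked again, while the farthest point takes most of the probability.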

 From the description above, the key is step 3: how to turn D(x) into the probability of a point being chosen. One way to do it is:

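  1. First pick a random point from our data set as the first "seed point".
  2. For each point, compute the distance D(x) to its nearest "seed point" and store it in an array, then add these distances up to get Sum(D(x)).
  3. Then draw a random value and use it as a weighted pick for the next "seed point". In practice: take a random value Random that falls within Sum(D(x)), then go through the points doing Random -= D(x) until Random <= 0; the point reached at that moment is the next "seed point".
    • Random can be drawn like this: Random = Sum(D(x)) * (a fraction between 0 and 1)
    • Random is taken within Sum(D(x)) because it is uniformly random over that range, so it is more likely to land in an interval where D(x) is large. In the figure below, Random is most likely to land in D(x3).
    • Random -= D(x) is how we find out which interval the current Random landed in.

      [Figure: Sum(D(x)) drawn as a line divided into consecutive intervals D(x1), D(x2), D(x3), ...]

      From the figure: suppose Random lands in the D(x3) interval. Repeating Random -= D(x) until it is <= 0 then stops at x3, so x3 is the center chosen in this step.

  4. Repeat steps 2 and 3 until k cluster centers have been chosen.
  5. Run the standard k-means algorithm with these k initial cluster centers.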
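
The Random -= D(x) walk from step 3 can be tried on its own. The sketch below is only an illustration: `pick_index` is a hypothetical helper, and the four weights are made-up numbers chosen so that D(x3) is the widest interval.

```python
import random

def pick_index(d, rand_value):
    """Return the index of the interval that rand_value falls into.

    d holds the weights D(x) laid end to end; rand_value lies in [0, sum(d)).
    """
    remaining = rand_value
    for j, dj in enumerate(d):
        remaining -= dj
        if remaining <= 0:
            return j
    return len(d) - 1              # guard against floating-point leftovers

d = [1.0, 2.0, 5.0, 2.0]           # D(x1)..D(x4), made-up values
print(pick_index(d, 0.5))          # 0.5 - 1 <= 0, so index 0 (x1)
print(pick_index(d, 4.2))          # 4.2 - 1 - 2 - 5 <= 0, so index 2 (x3)

# Repeat the draw many times to see that wide intervals win more often.
rng = random.Random(0)
counts = [0] * len(d)
for _ in range(10000):
    counts[pick_index(d, rng.random() * sum(d))] += 1
print(counts)                      # roughly in the ratio 1 : 2 : 5 : 2
```

Each index is hit with probability D(x_j) / Sum(D(x)), which is exactly the "larger D(x), larger chance" rule from step 3.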

This algorithm is actually easier to follow alongside code, so here is a Python implementation of k-means++:

from math import pi, sin, cos
from collections import namedtuple
from random import random, choice
from copy import copy


FLOAT_MAX = 1e100


class Point:
    __slots__ = ["x", "y", "group"]
    def __init__(self, x=0.0, y=0.0, group=0):
        self.x, self.y, self.group = x, y, group


def generate_points(npoints, radius):
    points = [Point() for _ in range(npoints)]

    # note: this is not a uniform 2-d distribution
    for p in points:
        r = random() * radius
        ang = random() * 2 * pi
        p.x = r * cos(ang)
        p.y = r * sin(ang)

    return points

def nearest_cluster_center(point, cluster_centers):
    """Index of the closest cluster center and the squared distance to it"""
    def sqr_distance_2D(a, b):
        return (a.x - b.x) ** 2  +  (a.y - b.y) ** 2

    min_index = point.group
    min_dist = FLOAT_MAX

    for i, cc in enumerate(cluster_centers):
        d = sqr_distance_2D(cc, point)
        if min_dist > d:
            min_dist = d
            min_index = i

    return (min_index, min_dist)


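def kpp(points, cluster_centers):
    # first center: a randomly chosen data point
    cluster_centers[0] = copy(choice(points))
    d = [0.0 for _ in range(len(points))]

    for i in range(1, len(cluster_centers)):
        # d[j]: squared distance from point j to its nearest chosen center
        total = 0
        for j, p in enumerate(points):
            d[j] = nearest_cluster_center(p, cluster_centers[:i])[1]
            total += d[j]

        # random threshold in [0, total); subtract d[j] until it is used up
        total *= random()

        for j, di in enumerate(d):
            total -= di
            if total > 0:
                continue
            cluster_centers[i] = copy(points[j])
            break
        else:
            # rounding left a positive remainder: fall back to the last point
            cluster_centers[i] = copy(points[-1])

    # assign every point to its nearest initial center
    for p in points:
        p.group = nearest_cluster_center(p, cluster_centers)[0]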


def lloyd(points, nclusters):
    cluster_centers = [Point() for _ in range(nclusters)]

    # call k++ init
    kpp(points, cluster_centers)

    lenpts10 = len(points) >> 10

    changed = 0
    while True:
        # group element for centroids are used as counters
        for cc in cluster_centers:
            cc.x = 0
            cc.y = 0
            cc.group = 0

        for p in points:
            cluster_centers[p.group].group += 1
            cluster_centers[p.group].x += p.x
            cluster_centers[p.group].y += p.y

        for cc in cluster_centers:
            if cc.group:  # an empty cluster stays at (0, 0); avoids dividing by zero
                cc.x /= cc.group
                cc.y /= cc.group

        # find the closest centroid of each point
        changed = 0
        for p in points:
            min_i = nearest_cluster_center(p, cluster_centers)[0]
            if min_i != p.group:
                changed += 1
                p.group = min_i

        # stop when 99.9% of points are good
        if changed <= lenpts10:
            break

    for i, cc in enumerate(cluster_centers):
        cc.group = i

    return cluster_centers


def print_eps(points, cluster_centers, W=400, H=400):
    Color = namedtuple("Color", "r g b")

    colors = []
    for i in range(len(cluster_centers)):
        colors.append(Color((3 * (i + 1) % 11) / 11.0,
                            (7 * i % 11) / 11.0,
                            (9 * i % 11) / 11.0))

    max_x = max_y = -FLOAT_MAX
    min_x = min_y = FLOAT_MAX

    for p in points:
        if max_x < p.x: max_x = p.x
        if min_x > p.x: min_x = p.x
        if max_y < p.y: max_y = p.y
        if min_y > p.y: min_y = p.y

    scale = min(W / (max_x - min_x),
                H / (max_y - min_y))
    cx = (max_x + min_x) / 2
    cy = (max_y + min_y) / 2

    print("%%!PS-Adobe-3.0\n%%%%BoundingBox: -5 -5 %d %d" % (W + 10, H + 10))

    print ("/l {rlineto} def /m {rmoveto} def\n" +
           "/c { .25 sub exch .25 sub exch .5 0 360 arc fill } def\n" +
           "/s { moveto -2 0 m 2 2 l 2 -2 l -2 -2 l closepath " +
           "   gsave 1 setgray fill grestore gsave 3 setlinewidth" +
           " 1 setgray stroke grestore 0 setgray stroke }def")

    for i, cc in enumerate(cluster_centers):
        print ("%g %g %g setrgbcolor" %
               (colors[i].r, colors[i].g, colors[i].b))

        for p in points:
            if p.group != i:
                continue
            print ("%.3f %.3f c" % ((p.x - cx) * scale + W / 2,
                                    (p.y - cy) * scale + H / 2))

        print ("\n0 setgray %g %g s" % ((cc.x - cx) * scale + W / 2,
                                        (cc.y - cy) * scale + H / 2))

    print("\n%%EOF")


def main():
    npoints = 30000
    k = 7  # number of clusters

    points = generate_points(npoints, 10)
    cluster_centers = lloyd(points, k)
    print_eps(points, cluster_centers)


if __name__ == "__main__":
    main()
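
As a side note, on Python 3.6 or later the subtraction loop inside `kpp` could probably be replaced by the standard library's `random.choices`, which does weighted sampling directly. The sketch below is my own illustration and not part of the listing above; `next_seed_index` is a hypothetical helper and the weights are made-up squared distances.

```python
import random

def next_seed_index(d, rng=random):
    """Pick an index j with probability proportional to d[j] (hypothetical helper)."""
    return rng.choices(range(len(d)), weights=d, k=1)[0]

rng = random.Random(42)            # seeded only to make the run repeatable
d = [0.0, 1.0, 9.0]                # squared distances to the nearest center
picks = [next_seed_index(d, rng) for _ in range(1000)]
print(picks.count(0), picks.count(1), picks.count(2))
```

Index 0 has weight 0 and is never picked, and index 2 should come up about nine times as often as index 1. The hand-written loop in `kpp` does the same job without building a cumulative weight list on every call.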

  

