【機器學習】K-means三維聚類,進階版,python


K-means是一種常用的聚類算法,進階版展示如下,代碼傳送門:

import random
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


# 正規化數據集 X
def normalize(X, axis=-1, p=2):
    """Scale X so that every vector along `axis` has unit Lp norm.

    Vectors whose norm is exactly zero are divided by 1 instead, so they
    are returned unchanged rather than producing a division by zero.

    Parameters:
    -----------
    X: numpy array
        Data to normalise.
    axis: int
        Axis along which the norms are computed (default: last axis).
    p: int or float
        Order of the norm, forwarded to np.linalg.norm (default: 2).
    """
    norms = np.atleast_1d(np.linalg.norm(X, ord=p, axis=axis))
    # Substitute 1 for zero norms so the division below is always defined.
    safe_norms = np.where(norms == 0, 1, norms)
    divisor = np.expand_dims(safe_norms, axis)
    return X / divisor


# 計算一個樣本與數據集中所有樣本的歐氏距離的平方
def euclidean_distance(one_sample, X):
    """Return the squared Euclidean distance from one sample to each row of X.

    Both inputs are flattened per sample first: `one_sample` becomes a single
    row and `X` becomes a 2-D array with one row per sample. No square root
    is taken, so the values are squared distances.

    Returns a 1-D array with X.shape[0] entries.
    """
    row = one_sample.reshape(1, -1)
    flat = X.reshape(X.shape[0], -1)
    # Broadcasting stretches the single row across every row of `flat`.
    deltas = row - flat
    squared = np.power(deltas, 2)
    return squared.sum(axis=1)



class Kmeans():
    """K-means clustering.

    Parameters:
    -----------
    k: int
        Number of clusters.
    max_iterations: int
        Maximum number of assign/update iterations.
    varepsilon: float
        Convergence tolerance. Iteration stops as soon as every coordinate
        of every centroid moved by less than varepsilon between two
        consecutive iterations.

    After predict() has run, the final centroids are available as
    `self.centroids_` (array of shape (k, n_features)).
    """
    def __init__(self, k=2, max_iterations=500, varepsilon=0.0001):
        self.k = k
        self.max_iterations = max_iterations
        self.varepsilon = varepsilon

    @staticmethod
    def _nearest_centroid_indices(X, centroids):
        """Return, for each row of X, the index of the closest centroid.

        Squared Euclidean distance is used; ties go to the lowest index.
        """
        X = np.asarray(X, dtype=float)
        X = X.reshape(X.shape[0], -1)
        centroids = np.asarray(centroids, dtype=float)
        centroids = centroids.reshape(centroids.shape[0], -1)
        # Loop over the (few) centroids rather than the (many) samples, so
        # the temporary storage stays at n_samples x k.
        distances = np.empty((X.shape[0], centroids.shape[0]))
        for i, centroid in enumerate(centroids):
            distances[:, i] = np.square(X - centroid).sum(axis=1)
        return np.argmin(distances, axis=1)

    def init_random_centroids(self, X):
        """Pick self.k samples of X at random as the initial centroids.

        Samples are drawn without replacement so two centroids never start
        on the same sample; replacement is only used in the degenerate case
        where k exceeds the number of samples.
        """
        X = np.asarray(X, dtype=float)
        n_samples = X.shape[0]
        chosen = np.random.choice(n_samples, size=self.k,
                                  replace=self.k > n_samples)
        return X[chosen].copy()

    def _closest_centroid(self, sample, centroids):
        """Return the index in [0, len(centroids)) of the centroid nearest to `sample`."""
        sample = np.asarray(sample, dtype=float).reshape(1, -1)
        return self._nearest_centroid_indices(sample, centroids)[0]

    def create_clusters(self, centroids, X):
        """Assign every sample to its nearest centroid.

        Returns a list of self.k lists; list i holds the (ascending) indices
        of the samples assigned to centroid i.
        """
        nearest = self._nearest_centroid_indices(X, centroids)
        return [np.flatnonzero(nearest == i).tolist() for i in range(self.k)]

    def update_centroids(self, clusters, X, previous_centroids=None):
        """Recompute each centroid as the mean of the samples in its cluster.

        An empty cluster has no mean. In that case the centroid keeps its
        previous position when `previous_centroids` is given; otherwise it
        is re-seeded on a randomly chosen sample. This avoids NaN centroids.
        """
        X = np.asarray(X, dtype=float)
        n_features = X.shape[1]
        centroids = np.zeros((self.k, n_features))
        for i, cluster in enumerate(clusters):
            if len(cluster) > 0:
                centroids[i] = X[cluster].mean(axis=0)
            elif previous_centroids is not None:
                centroids[i] = previous_centroids[i]
            else:
                centroids[i] = X[np.random.randint(X.shape[0])]
        return centroids

    def get_cluster_labels(self, clusters, X):
        """Convert a list of clusters into a per-sample label array.

        The label of a sample is the index of the cluster that contains it.
        The result is a float array of length X.shape[0].
        """
        y_pred = np.zeros(np.shape(X)[0])
        for cluster_i, cluster in enumerate(clusters):
            y_pred[np.asarray(cluster, dtype=int)] = cluster_i
        return y_pred

    def predict(self, X):
        """Run K-means on the data set X and return the cluster label of every sample."""
        X = np.asarray(X, dtype=float)
        centroids = self.init_random_centroids(X)

        # Iterate until the centroids stop moving or max_iterations is hit.
        for _ in range(self.max_iterations):
            clusters = self.create_clusters(centroids, X)
            former_centroids = centroids
            centroids = self.update_centroids(clusters, X, former_centroids)

            # Converged only when *every* coordinate shift is below the
            # tolerance; compare magnitudes, not the truthiness of the diff.
            shift = np.abs(centroids - former_centroids)
            if np.all(shift < self.varepsilon):
                break

        self.centroids_ = centroids
        # Assign against the final centroids so the labels match them, and
        # so that max_iterations=0 still yields a valid labelling.
        clusters = self.create_clusters(centroids, X)
        return self.get_cluster_labels(clusters, X)


def main():
    """Cluster a synthetic 3-D blob data set with Kmeans and plot the result.

    Points are coloured by the label Kmeans predicted for them, so the
    figure shows the clustering outcome rather than the generating labels.
    """
    # Generate four Gaussian blobs in 3-D.
    centers = [[3, 3, 3], [0, 0, 0], [1, 1, 1], [2, 2, 2]]
    X, _ = datasets.make_blobs(n_samples=10000,
                               n_features=3,
                               centers=centers,
                               cluster_std=[0.2, 0.1, 0.2, 0.2],
                               random_state=9)

    # Cluster with K-means, one cluster per generated blob.
    n_clusters = len(centers)
    clf = Kmeans(k=n_clusters)
    y_pred = clf.predict(X)

    # Visualise the clustering. The 3-D axes must be created through the
    # figure: a bare Axes3D(fig) is not attached to the figure on current
    # Matplotlib, and pyplot's scatter would treat the third positional
    # argument as the marker size rather than as z.
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_axes([0, 0, 1, 1], projection="3d")
    ax.view_init(elev=30, azim=20)
    for label in range(n_clusters):
        members = X[y_pred == label]
        ax.scatter(members[:, 0], members[:, 1], members[:, 2])
    plt.show()


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

效果圖:

備註:本文代碼並非原創。因需要做聚類,我幾乎將博客裡關於這部分的代碼都嘗試了一遍,這份代碼是沒有報錯的,感恩大神。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM