K-Means算法的Java實現


K-means算法是硬聚類算法,是典型的基於原型的目標函數聚類方法的代表,它是數據點到原型的某種距離作為優化的目標函數,利用函數求極值的方法得到迭代運算的調整規則。K-means算法以歐式距離作為相似度測度,它是求對應某一初始聚類中心向量V最優分類,使得評價指標J最小。算法采用誤差平方和准則函數作為聚類准則函數。

package com.coshaho.learn.kmeans;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;

public class KMeans 
{
    //聚類的數目
    final static int ClassCount = 3;
    //樣本數目(測試集)
    final static int InstanceNumber = 150; 
    //樣本屬性數目(測試)
    final static int FieldCount = 5;
    
    //設置異常點閾值參數(每一類初始的最小數目為InstanceNumber/ClassCount^t)
    final static double t = 2.0;
    
    //存放數據的矩陣
    private float[][] data;
    //每個類的均值中心
    private float[][] classData;
    //噪聲集合索引
    private ArrayList<Integer> noises;
    //存放每次變換結果的矩陣
    private ArrayList<ArrayList<Integer>> result;
    
    public KMeans()
    {
        //最后一位用來儲存結果
        data = new float[InstanceNumber][FieldCount+1];
        classData = new float[ClassCount][FieldCount];
        result = new ArrayList<ArrayList<Integer>>(ClassCount);
        noises = new ArrayList<Integer>();
    }
    
    public void readData(String TrainDataFile)
    {
        FileReader fr = null;
        BufferedReader br = null;
        try
        {
            fr = new FileReader(TrainDataFile);
            br = new BufferedReader(fr);
            //存放數據的臨時變量
            String lineData = null;
            String[] splitData = null;
            int line = 0;
            while( br.ready())
            {
                lineData = br.readLine();
                System.out.println(lineData);
                splitData = lineData.split(",");
                for(int i = 0 ; i < splitData.length ;i++)
                {
                    data[line][i] = Float.parseFloat(splitData[i]);
                }
                line++;
            }
        }
        catch(Exception e)
        {
            e.printStackTrace();
        }
        finally
        {
            if(null != br)
            {
                try 
                {
                    br.close();
                } 
                catch (IOException e) 
                {
                    e.printStackTrace();
                }
            }
            if(null != fr)
            {
                try 
                {
                    fr.close();
                } 
                catch (IOException e) 
                {
                    e.printStackTrace();
                }
            }
        }
    }
    
    public void cluster()
    {
        //數據歸一化
        normalize();
        
        //標記是否需要重新找初始點
        boolean needUpdataInitials = true;
       
        //找初始點的迭代次數
        int times = 1;
        
        //找初始點
        while(needUpdataInitials)
        {
            needUpdataInitials = false;
            result.clear();
            System.out.println("Find Initials Iteration"+(times++)+"time(s)");
      
            //一次找初始點的嘗試和根據初始點的分類
            findInitials();
            firstClassify();
      
            for(int i = 0;i < result.size();i++)
            {
                if(result.get(i).size() < InstanceNumber/Math.pow(ClassCount,t))
                {
                    needUpdataInitials = true;
                    noises.addAll(result.get(i));
                }
            }
        }
       
        Adjust();
    }
 
    /**
     * 數據歸一化
     * @author coshaho
     */
    private void normalize()
    {
        // 計算數據每個維度最大值max
        float[] max = new float[FieldCount];
        for(int i = 0;i < InstanceNumber;i++)
        {
            for(int j = 0;j < FieldCount;j++)
            {
                if(data[i][j] > max[j])
                {
                    max[j] = data[i][j];
                }
            }
        }
       
        // 每個維度歸一化值=原始值/max
        for(int i = 0;i < InstanceNumber;i++)
        {
            for(int j = 0;j < FieldCount;j++)
            {
                data[i][j] = data[i][j]/max[j];
            }
        }
    }

    /**
     * 尋找初始聚類中心
     * @author coshaho
     */
    private void findInitials() 
    {
        int i, j, a, b;
        i = j = a = b = 0;
        float maxDis = 0;
        int alreadyCls = 2;

        // 選取距離最遠的兩個點a,b作為聚類中心點
        ArrayList<Integer> initials = new ArrayList<Integer>();
        for (; i < InstanceNumber; i++) 
        {
            // 噪聲點不參與計算
            if (noises.contains(i))
            {
                continue;
            }
            j = i + 1;
            for (; j < InstanceNumber; j++) 
            {
                // 噪聲點不參與計算
                if (noises.contains(j))
                {
                    continue;
                }
                float newDis = calDis(data[i], data[j]);
                if (maxDis < newDis) 
                {
                    a = i;
                    b = j;
                    maxDis = newDis;
                }
            }
        }

        // initials添加初始聚類中心點序號a,b
        initials.add(a);
        initials.add(b);
        
        // classData添加聚類中心點data[a],data[b]
        classData[0] = data[a];
        classData[1] = data[b];

        // 新增兩個聚類,並添加聚類成員
        ArrayList<Integer> resultOne = new ArrayList<Integer>();
        ArrayList<Integer> resultTwo = new ArrayList<Integer>();
        resultOne.add(a);
        resultTwo.add(b);
        result.add(resultOne);
        result.add(resultTwo);

        // 1、計算剩下每個點x與其他點的最小距離l,並記錄Map<x,l>
        // 2、選取Map<x,l>中的最大l,並以對應的點x作為新的聚類中心
        while (alreadyCls < ClassCount) 
        {
            i = j = 0;
            float maxMin = 0;
            int newClass = -1;

            for (; i < InstanceNumber; i++) 
            {
                float min = 0;
                float newMin = 0;
                if (initials.contains(i))
                {
                    continue;
                }
                if (noises.contains(i))
                {
                    continue;
                }
                for (j = 0; j < alreadyCls; j++) 
                {
                    newMin = calDis(data[i], classData[j]);
                    if (min == 0 || newMin < min)
                    {
                        min = newMin;
                    }
                }
                if (min > maxMin) 
                {
                    maxMin = min;
                    newClass = i;
                }
            }
            
            // initials添加新的聚類中心點序號newClass
            initials.add(newClass);
            
            // classData添加新的聚類中心點data[newClass]
            classData[alreadyCls++] = data[newClass];
            
            // 新增一個聚類,並添加成員
            ArrayList<Integer> rslt = new ArrayList<Integer>();
            rslt.add(newClass);
            result.add(rslt);
        }
    }

    /**
     * 首次聚類分配
     * 點x到哪個聚類中心點最近,則划分到哪個聚類
     * @author coshaho
     */
    public void firstClassify() 
    {
        for (int i = 0; i < InstanceNumber; i++) 
        {
            float min = 0f;
            int clsId = -1;
            for (int j = 0; j < classData.length; j++) 
            {
                // 歐式距離
                float newMin = calDis(classData[j], data[i]);
                if (clsId == -1 || newMin < min) 
                {
                    clsId = j;
                    min = newMin;
                }
            }
            
            if (!result.get(clsId).contains(i))
            {
                result.get(clsId).add(i);
            }
        }
    }

    // 迭代分類,直到各個類的數據不再變化
    public void Adjust() 
    {
        // 記錄是否發生變化
        boolean change = true;

        // 循環的次數
        int times = 1;
        while (change) 
        {
            // 復位
            change = false;
            System.out.println("Adjust Iteration" + (times++) + "time(s)");

            // 重新計算每個類的均值
            for (int i = 0; i < ClassCount; i++) 
            {
                // 原有的數據
                ArrayList<Integer> cls = result.get(i);

                // 新的均值
                float[] newMean = new float[FieldCount];

                // 計算均值
                for (Integer index : cls) 
                {
                    for (int j = 0; j < FieldCount; j++)
                        newMean[j] += data[index][j];
                }
                for (int j = 0; j < FieldCount; j++)
                {
                    newMean[j] /= cls.size();
                }
                if (!compareMean(newMean, classData[i])) 
                {
                    classData[i] = newMean;
                    change = true;
                }
            }
            // 清空之前的數據
            for (ArrayList<Integer> cls : result)
            {
                cls.clear();
            }

            // 重新分配
            for (int i = 0; i < InstanceNumber; i++) 
            {
                float min = 0f;
                int clsId = -1;
                for (int j = 0; j < classData.length; j++) 
                {
                    float newMin = calDis(classData[j], data[i]);
                    if (clsId == -1 || newMin < min) 
                    {
                        clsId = j;
                        min = newMin;
                    }
                }
                data[i][FieldCount] = clsId;
                result.get(clsId).add(i);
            }
        }
    }

    /**
     * 計算a樣本和b樣本的歐式距離作為不相似度
     * 
     * @param a 樣本a
     * @param b 樣本b
     * @return 歐式距離長度
     */
    private float calDis(float[] aVector, float[] bVector) {
        double dis = 0;
        int i = 0;
        /* 最后一個數據在訓練集中為結果,所以不考慮 */
        for (; i < aVector.length; i++)
            dis += Math.pow(bVector[i] - aVector[i], 2);
        dis = Math.pow(dis, 0.5);
        return (float) dis;
    }

    /**
     * 判斷兩個均值向量是否相等
     * 
     * @param a 向量a
     * @param b 向量b
     * @return
     */
    private boolean compareMean(float[] a, float[] b) {
        if (a.length != b.length)
            return false;
        for (int i = 0; i < a.length; i++) {
            if (a[i] > 0 && b[i] > 0 && a[i] != b[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * 將結果輸出到一個文件中
     * 
     * @param fileName
     */
    public void printResult(String fileName) 
    {
        FileWriter fw = null;
        BufferedWriter bw = null;
        try 
        {
            fw = new FileWriter(fileName);
            bw = new BufferedWriter(fw);
            // 寫入文件
            for (int i = 0; i < InstanceNumber; i++) 
            {
                bw.write(String.valueOf(data[i][FieldCount]).substring(0, 1));
                bw.newLine();
            }

            // 統計每類的數目,打印到控制台
            for (int i = 0; i < ClassCount; i++) 
            {
                System.out.println("第" + (i + 1) + "類數目: "
                        + result.get(i).size());
            }
        } 
        catch (IOException e) 
        {
            e.printStackTrace();
        } 
        finally 
        {

            // 關閉資源
            if (bw != null)
            {
                try 
                {
                    bw.close();
                } 
                catch (IOException e) 
                {
                    e.printStackTrace();
                }
            }
            if (fw != null)
            {
                try 
                {
                    fw.close();
                } 
                catch (IOException e) 
                {
                    e.printStackTrace();
                }
            }
        }
    }
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM