K-means算法是硬聚類算法,是典型的基於原型的目標函數聚類方法的代表,它是數據點到原型的某種距離作為優化的目標函數,利用函數求極值的方法得到迭代運算的調整規則。K-means算法以歐式距離作為相似度測度,它是求對應某一初始聚類中心向量V最優分類,使得評價指標J最小。算法采用誤差平方和准則函數作為聚類准則函數。
package com.coshaho.learn.kmeans; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.util.ArrayList; public class KMeans { //聚類的數目 final static int ClassCount = 3; //樣本數目(測試集) final static int InstanceNumber = 150; //樣本屬性數目(測試) final static int FieldCount = 5; //設置異常點閾值參數(每一類初始的最小數目為InstanceNumber/ClassCount^t) final static double t = 2.0; //存放數據的矩陣 private float[][] data; //每個類的均值中心 private float[][] classData; //噪聲集合索引 private ArrayList<Integer> noises; //存放每次變換結果的矩陣 private ArrayList<ArrayList<Integer>> result; public KMeans() { //最后一位用來儲存結果 data = new float[InstanceNumber][FieldCount+1]; classData = new float[ClassCount][FieldCount]; result = new ArrayList<ArrayList<Integer>>(ClassCount); noises = new ArrayList<Integer>(); } public void readData(String TrainDataFile) { FileReader fr = null; BufferedReader br = null; try { fr = new FileReader(TrainDataFile); br = new BufferedReader(fr); //存放數據的臨時變量 String lineData = null; String[] splitData = null; int line = 0; while( br.ready()) { lineData = br.readLine(); System.out.println(lineData); splitData = lineData.split(","); for(int i = 0 ; i < splitData.length ;i++) { data[line][i] = Float.parseFloat(splitData[i]); } line++; } } catch(Exception e) { e.printStackTrace(); } finally { if(null != br) { try { br.close(); } catch (IOException e) { e.printStackTrace(); } } if(null != fr) { try { fr.close(); } catch (IOException e) { e.printStackTrace(); } } } } public void cluster() { //數據歸一化 normalize(); //標記是否需要重新找初始點 boolean needUpdataInitials = true; //找初始點的迭代次數 int times = 1; //找初始點 while(needUpdataInitials) { needUpdataInitials = false; result.clear(); System.out.println("Find Initials Iteration"+(times++)+"time(s)"); //一次找初始點的嘗試和根據初始點的分類 findInitials(); firstClassify(); for(int i = 0;i < result.size();i++) { if(result.get(i).size() < InstanceNumber/Math.pow(ClassCount,t)) { needUpdataInitials = true; noises.addAll(result.get(i)); } } } Adjust(); } /** * 數據歸一化 * @author coshaho */ private void normalize() { // 計算數據每個維度最大值max float[] max = new float[FieldCount]; for(int i = 0;i < InstanceNumber;i++) { for(int j = 0;j < FieldCount;j++) { if(data[i][j] > max[j]) { max[j] = data[i][j]; } } } // 每個維度歸一化值=原始值/max for(int i = 0;i < InstanceNumber;i++) { for(int j = 0;j < FieldCount;j++) { data[i][j] = data[i][j]/max[j]; } } } /** * 尋找初始聚類中心 * @author coshaho */ private void findInitials() { int i, j, a, b; i = j = a = b = 0; float maxDis = 0; int alreadyCls = 2; // 選取距離最遠的兩個點a,b作為聚類中心點 ArrayList<Integer> initials = new ArrayList<Integer>(); for (; i < InstanceNumber; i++) { // 噪聲點不參與計算 if (noises.contains(i)) { continue; } j = i + 1; for (; j < InstanceNumber; j++) { // 噪聲點不參與計算 if (noises.contains(j)) { continue; } float newDis = calDis(data[i], data[j]); if (maxDis < newDis) { a = i; b = j; maxDis = newDis; } } } // initials添加初始聚類中心點序號a,b initials.add(a); initials.add(b); // classData添加聚類中心點data[a],data[b] classData[0] = data[a]; classData[1] = data[b]; // 新增兩個聚類,並添加聚類成員 ArrayList<Integer> resultOne = new ArrayList<Integer>(); ArrayList<Integer> resultTwo = new ArrayList<Integer>(); resultOne.add(a); resultTwo.add(b); result.add(resultOne); result.add(resultTwo); // 1、計算剩下每個點x與其他點的最小距離l,並記錄Map<x,l> // 2、選取Map<x,l>中的最大l,並以對應的點x作為新的聚類中心 while (alreadyCls < ClassCount) { i = j = 0; float maxMin = 0; int newClass = -1; for (; i < InstanceNumber; i++) { float min = 0; float newMin = 0; if (initials.contains(i)) { continue; } if (noises.contains(i)) { continue; } for (j = 0; j < alreadyCls; j++) { newMin = calDis(data[i], classData[j]); if (min == 0 || newMin < min) { min = newMin; } } if (min > maxMin) { maxMin = min; newClass = i; } } // initials添加新的聚類中心點序號newClass initials.add(newClass); // classData添加新的聚類中心點data[newClass] classData[alreadyCls++] = data[newClass]; // 新增一個聚類,並添加成員 ArrayList<Integer> rslt = new ArrayList<Integer>(); rslt.add(newClass); result.add(rslt); } } /** * 首次聚類分配 * 點x到哪個聚類中心點最近,則划分到哪個聚類 * @author coshaho */ public void firstClassify() { for (int i = 0; i < InstanceNumber; i++) { float min = 0f; int clsId = -1; for (int j = 0; j < classData.length; j++) { // 歐式距離 float newMin = calDis(classData[j], data[i]); if (clsId == -1 || newMin < min) { clsId = j; min = newMin; } } if (!result.get(clsId).contains(i)) { result.get(clsId).add(i); } } } // 迭代分類,直到各個類的數據不再變化 public void Adjust() { // 記錄是否發生變化 boolean change = true; // 循環的次數 int times = 1; while (change) { // 復位 change = false; System.out.println("Adjust Iteration" + (times++) + "time(s)"); // 重新計算每個類的均值 for (int i = 0; i < ClassCount; i++) { // 原有的數據 ArrayList<Integer> cls = result.get(i); // 新的均值 float[] newMean = new float[FieldCount]; // 計算均值 for (Integer index : cls) { for (int j = 0; j < FieldCount; j++) newMean[j] += data[index][j]; } for (int j = 0; j < FieldCount; j++) { newMean[j] /= cls.size(); } if (!compareMean(newMean, classData[i])) { classData[i] = newMean; change = true; } } // 清空之前的數據 for (ArrayList<Integer> cls : result) { cls.clear(); } // 重新分配 for (int i = 0; i < InstanceNumber; i++) { float min = 0f; int clsId = -1; for (int j = 0; j < classData.length; j++) { float newMin = calDis(classData[j], data[i]); if (clsId == -1 || newMin < min) { clsId = j; min = newMin; } } data[i][FieldCount] = clsId; result.get(clsId).add(i); } } } /** * 計算a樣本和b樣本的歐式距離作為不相似度 * * @param a 樣本a * @param b 樣本b * @return 歐式距離長度 */ private float calDis(float[] aVector, float[] bVector) { double dis = 0; int i = 0; /* 最后一個數據在訓練集中為結果,所以不考慮 */ for (; i < aVector.length; i++) dis += Math.pow(bVector[i] - aVector[i], 2); dis = Math.pow(dis, 0.5); return (float) dis; } /** * 判斷兩個均值向量是否相等 * * @param a 向量a * @param b 向量b * @return */ private boolean compareMean(float[] a, float[] b) { if (a.length != b.length) return false; for (int i = 0; i < a.length; i++) { if (a[i] > 0 && b[i] > 0 && a[i] != b[i]) { return false; } } return true; } /** * 將結果輸出到一個文件中 * * @param fileName */ public void printResult(String fileName) { FileWriter fw = null; BufferedWriter bw = null; try { fw = new FileWriter(fileName); bw = new BufferedWriter(fw); // 寫入文件 for (int i = 0; i < InstanceNumber; i++) { bw.write(String.valueOf(data[i][FieldCount]).substring(0, 1)); bw.newLine(); } // 統計每類的數目,打印到控制台 for (int i = 0; i < ClassCount; i++) { System.out.println("第" + (i + 1) + "類數目: " + result.get(i).size()); } } catch (IOException e) { e.printStackTrace(); } finally { // 關閉資源 if (bw != null) { try { bw.close(); } catch (IOException e) { e.printStackTrace(); } } if (fw != null) { try { fw.close(); } catch (IOException e) { e.printStackTrace(); } } } } }