k近鄰算法-java實現


最近在看《機器學習實戰》這本書,因為自己本身很想深入的了解機器學習算法,加之想學python,就在朋友的推薦之下選擇了這本書進行學習。 

一 . K-近鄰算法(KNN)概述 

    最簡單最初級的分類器是將全部的訓練數據所對應的類別都記錄下來,當測試對象的屬性和某個訓練對象的屬性完全匹配時,便可以對其進行分類。但是怎么可能所有測試對象都會找到與之完全匹配的訓練對象呢,其次就是存在一個測試對象同時與多個訓練對象匹配,導致一個訓練對象被分到了多個類的問題,基於這些問題呢,就產生了KNN。

     KNN是通過測量不同特征值之間的距離進行分類。它的的思路是:如果一個樣本在特征空間中的k個最相似(即特征空間中最鄰近)的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別。K通常是不大於20的整數。KNN算法中,所選擇的鄰居都是已經正確分類的對象。該方法在定類決策上只依據最鄰近的一個或者幾個樣本的類別來決定待分樣本所屬的類別。

     下面通過一個簡單的例子說明一下:如下圖,綠色圓要被決定賦予哪個類,是紅色三角形還是藍色四方形?如果K=3,由於紅色三角形所占比例為2/3,綠色圓將被賦予紅色三角形那個類,如果K=5,由於藍色四方形比例為3/5,因此綠色圓被賦予藍色四方形類。

 

由此也說明了KNN算法的結果很大程度取決於K的選擇。

     在KNN中,通過計算對象間距離來作為各個對象之間的非相似性指標,避免了對象之間的匹配問題,在這里距離一般使用歐氏距離或曼哈頓距離:

                      

同時,KNN通過依據k個對象中占優的類別進行決策,而不是單一的對象類別決策。這兩點就是KNN算法的優勢。

   接下來對KNN算法的思想總結一下:就是在訓練集中數據和標簽已知的情況下,輸入測試數據,將測試數據的特征與訓練集中對應的特征進行相互比較,找到訓練集中與之最為相似的前K個數據,則該測試數據對應的類別就是K個數據中出現次數最多的那個分類,其算法的描述為:

1)計算測試數據與各個訓練數據之間的距離;

2)按照距離的遞增關系進行排序;

3)選取距離最小的K個點;

4)確定前K個點所在類別的出現頻率;

5)返回前K個點中出現頻率最高的類別作為測試數據的預測分類。

代碼實現:

 

import java.util.*;

/**
 * code by me
 * <p>
 * Data:2017/8/17 Time:16:40
 * User:lbh
 */
public class KNN {

    /**
     * KNN數據模型
     */
    public static class KNNModel implements Comparable<KNNModel> {
        public double a;
        public double b;
        public double c;
        public double distince;
        String type;

        public KNNModel(double a, double b, double c, String type) {
            this.a = a;
            this.b = b;
            this.c = c;
            this.type = type;
        }
        /**
         * 按距離排序
         *
         * @param arg
         * @return
         */
        @Override
        public int compareTo(KNNModel arg) {
            return Double.valueOf(this.distince).compareTo(Double.valueOf(arg.distince));
        }
    }

    /**
     * 計算距離
     *
     * @param knnModelList
     * @param i
     */
    private static void calDistince(List<KNNModel> knnModelList, KNNModel i) {
        double distince;
        for (KNNModel m : knnModelList) {
            distince = Math.sqrt((i.a - m.a) * (i.a - m.a) + (i.b - m.b) * (i.b - m.b) + (i.c - m.c) * (i.c - m.c));
            m.distince = distince;
        }
    }

    /**
     * 找出前k個數據中分類最多的數據
     *
     * @param knnModelList
     * @return
     */
    private static String findMostData(List<KNNModel> knnModelList) {
        Map<String, Integer> typeCountMap = new HashMap<String, Integer>();
        String type = "";
        Integer tempVal = 0;
        // 統計分類個數
        for (KNNModel model : knnModelList) {
            if (typeCountMap.containsKey(model.type)) {
                typeCountMap.put(model.type, typeCountMap.get(model.type) + 1);
            } else {
                typeCountMap.put(model.type, 1);
            }
        }
        // 找出最多分類
        for (Map.Entry<String, Integer> entry : typeCountMap.entrySet()) {
            if (entry.getValue() > tempVal) {
                tempVal = entry.getValue();
                type = entry.getKey();
            }
        }
        return type;
    }

    /**
     * KNN 算法的實現
     *
     * @param k
     * @param knnModelList
     * @param inputModel
     * @return
     */
    public static String calKNN(int k, List<KNNModel> knnModelList, KNNModel inputModel) {
        System.out.println("1.計算距離");
        calDistince(knnModelList, inputModel);
        System.out.println("2.按距離(近-->遠)排序");
        Collections.sort(knnModelList);
        System.out.println("3.取前k個數據");
        while (knnModelList.size() > k) {
            knnModelList.remove(k);
        }
        System.out.println("4.找出前k個數據中分類出現頻率最大的數據");
        String type = findMostData(knnModelList);
        return type;
    }

    /**
     * 測試KNN算法
     *
     * @param args
     */
    public static void main(String[] args) {
        // 准備數據
        List<KNNModel> knnModelList = new ArrayList<KNNModel>();
        knnModelList.add(new KNNModel(1.1, 1.1, 1.1, "A"));
        knnModelList.add(new KNNModel(1.2, 1.1, 1.0, "A"));
        knnModelList.add(new KNNModel(1.1, 1.0, 1.0, "A"));
        knnModelList.add(new KNNModel(3.0, 3.1, 1.0, "B"));
        knnModelList.add(new KNNModel(3.1, 3.0, 1.0, "B"));
        knnModelList.add(new KNNModel(5.4, 6.0, 4.0, "C"));
        knnModelList.add(new KNNModel(5.5, 6.3, 4.1, "C"));
        knnModelList.add(new KNNModel(6.0, 6.0, 4.0, "C"));
        knnModelList.add(new KNNModel(10.0, 12.0, 10.0, "M"));
        // 預測數據
        KNNModel predictionData = new KNNModel(5.1, 6.2, 2.0, "NB");
        // 計算
        String result = calKNN(3, knnModelList, predictionData);
        System.out.println("預測結果:"+result);
    }
}

 

結果:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM