k近鄰算法的Java實現


k近鄰算法是機器學習算法中最簡單的算法之一,工作原理是:存在一個樣本數據集合,即訓練樣本集,並且樣本集中的每個數據都存在標簽,即我們知道樣本集中每一數據和所屬分類的對應關系。輸入沒有標簽的新數據之后,將新數據的每個特征和樣本集中數據對應的特征進行比較,然后算法提取樣本集中特征最相似數據的分類標簽作為新數據的標簽。一般來說,我們只選取樣本數據中前k個最相似的數據。

Java實現:

KNNData.java

package KNN;

public class KNNData implements Comparable<KNNData>{
    double c1;
    double c2;
    double c3;
    double distance;
    String type;
    
    public KNNData(double c1, double c2, double c3, String type) {
        this.c1 = c1;
        this.c2 = c2;
        this.c3 = c3;
        this.type = type;
    }
    
    @Override
    public int compareTo(KNNData arg0) {
        return Double.valueOf(this.distance).compareTo(Double.valueOf(arg0.distance));
    }    
}

KNN.java

package KNN;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class KNN {
    
    //訓練集
    private List<KNNData> KNNDS = null;
    
    public KNN(List<KNNData> KNNDS) {
        this.KNNDS = KNNDS;
    }
    
    //歐式距離
    private static double disCal(KNNData i, KNNData td) {
        return Math.sqrt((i.c1 - td.c1)*(i.c1 - td.c1)+(i.c2 - td.c2)*(i.c2 - td.c2)+
                (i.c3 - td.c3)*(i.c3 - td.c3));
    }
    
    private static String getMaxValueKey(int k, List<KNNData> ts){
        //只保留前k個元素
        
        while(ts.size() != k) {
            ts.remove(k);
        }
                
        String sKey;
        //保存key以及出現次數
        HashMap<String,Integer> keySet = new HashMap<String,Integer>();
        keySet.put(ts.get(0).type,1);
        for (int x = 1; x < ts.size(); x++) {
            sKey = ts.get(x).type;
            if (keySet.containsKey(sKey)) {
                keySet.put(sKey, keySet.get(sKey)+1);
            } else {
                keySet.put(sKey, 1);
            }
        }
        Set<Map.Entry<String,Integer>> set = keySet.entrySet();
        Iterator<Map.Entry<String,Integer>> iter = set.iterator(); 
        
        int mValue = 0;
        String mType = "";
        while (iter.hasNext()){
            Map.Entry<String,Integer> map = iter.next();
            if (mValue < map.getValue()) {
                mType = map.getKey();
                mValue = map.getValue();
            }
        }
        
        return mType;
    }
    
    public static String knnCal(int k, KNNData i, List<KNNData> ts) {
        //保存距離
        for (KNNData td : ts) {
            td.distance = disCal(i, td);
        }
        Collections.sort(ts);    
        return getMaxValueKey(k, ts);
    }
}

KNNTest.java

package KNN;

import java.util.ArrayList;
import java.util.List;

public class KNNTest {

    public static void main(String[] args) {
        
        List<KNNData> kd = new ArrayList<KNNData>();
        //訓練集
        kd.add(new KNNData(1.2,1.1,0.1,"A"));
        kd.add(new KNNData(1.2,1.1,0.1,"A"));
        kd.add(new KNNData(7,1.5,0.1,"B"));
        kd.add(new KNNData(6,1.2,0.1,"B"));
        kd.add(new KNNData(2,2.6,0.1,"C"));
        kd.add(new KNNData(2,2.6,0.1,"C"));
        kd.add(new KNNData(2,2.6,0.1,"C"));
        kd.add(new KNNData(100,1.1,0.1,"D"));

        System.out.println(KNN.knnCal(3, new KNNData(1.1,1.1,0.1,"N/A"), kd));
    }
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM