【十大算法實現之KNN】KNN算法實例(含測試數據和源碼)


KNN算法基本的思路是比較好理解的,今天根據它的特點寫了一個實例,我會把所有的數據和代碼都寫在下面供大家參考,不足之處,請指正。謝謝!

 

update:工程代碼全部在本頁面中,測試數據已丟失,建議去UCI Dataset中找一個自行測試一下。

 

幾點說明:

1.KNN中的K=5;

2.在計算權重時,采用的是減去函數{1,0.8,0.6,0.4,0.2},當然你也可以采用反函數或高斯函數;

3.5%作為測試集(decision.txt),95%作為訓練集(training.txt);

4.在計算costfun之前,對所有的屬性進行了歸一化,由於這里不知道數據集每個屬性代表的含義,所以就一視同仁,實際情況下,應該具體問題具體分析;

 

image

 

 

XBWKNN.java

package XBWKNN;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * KNN算法
 * @author XBW
 * @date 2014年8月16日
 */


public class XBWKNN{
    public final static int KofKNN=5;
    public final static double weight[]={1,0.9,0.7,0.4,0.1};                //減法函數y=1-0.2*x
    

    /**
     * knn
     * @param data
     * @param ds
     * @return ans
     */
    public static int knn(Data data,DataSet ds){
        int ans = 0;
        List<Data> dis=calcDis(data,ds);
        ans=calcKDis(data,dis);
        return ans;
    }
    
    /**
     * 計算訓練集中所有向量的距離,排序之后取前K個
     * @param data
     * @param ds
     * @return
     */
    @SuppressWarnings("null")
    public static List<Data>calcDis(Data data,DataSet ds){
        List<Data> anslist =new ArrayList<Data>();
        double dx1=data.x1;
        double dx2=data.x2;
        double dx3=data.x3;
        for(int i=0;i<ds.ds.size();i++){
            double x1=ds.ds.get(i).x1;
            double x2=ds.ds.get(i).x2;
            double x3=ds.ds.get(i).x3;
            ds.ds.get(i).costfun=Math.sqrt((dx1-x1)*(dx1-x1)+(dx2-x2)*(dx2-x2)+(dx3-x3)*(dx3-x3));
            anslist.add(ds.ds.get(i));
        }
        Collections.sort(anslist,new Comparator<Data>(){
               public int compare(Data o1, Data o2) {
                   Double s=o1.costfun-o2.costfun;
                   if(s<0)
                       return -1;
                   else
                       return 1; 
                }
        });
        return anslist;
    }
    
    
    /**
     * 按一定的權重計算出前K個
     * @param data
     * @param ds
     * @return
     */
    public static int calcKDis(Data data,List<Data> anslist){
        Double[] anstype={0.0,0.0,0.0,0.0};
        for(int i=0;i<KofKNN;i++){
            if(anslist.get(i).type==1){
                anstype[1]+=weight[i];
            }
            else if(anslist.get(i).type==2){
                anstype[2]+=weight[i];
            }
            if(anslist.get(i).type==3){
                anstype[3]+=weight[i];
            }
        }
        Double maxt=-1.0;
        int tag=1;
        for(int i=1;i<=3;i++){
            if(maxt<anstype[i]){
                tag=i;
                maxt=anstype[i];
            }
        }
        return tag;
    }
    
    public static void main(String[] args) throws IOException{
        DataSet ds=new DataSet();
        DataTest dt=new DataTest();
        
        int correct=0;
        for(int i=0;i<dt.dt.size();i++){
            Data data=dt.dt.get(i);
            int result=knn(data,ds);
            if(result==data.type){
                correct++;
            }
        }
        System.out.println("total test num :"+dt.dt.size());
        System.out.println("correct test num :"+correct);
        System.out.println("ratio :"+correct/(double)dt.dt.size());
    }
}

 

Datatest.java

package XBWKNN;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;




/**
 * 測試數據
 * @author XBW
 * @date 2014年8月16日
 */

public class DataTest{
    String defaultpath="D:\\MachineLearning\\十大算法\\KNN\\knncode\\decision.txt";
    List<Data> dt;
    
    @SuppressWarnings("null")
    public DataTest() throws IOException{
        List<Data> dset = new ArrayList<Data>();
        File ds=new File(defaultpath);
        @SuppressWarnings("resource")
        BufferedReader br = new BufferedReader(new FileReader(ds));
        String tsing;
        double max1=-1;
        double max2=-1;
        double max3=-1;
        while((tsing=br.readLine())!=null){
            String[] dlist=tsing.split("    ");
            Data data=new Data();
            data.x1=Double.parseDouble(dlist[0]);
            data.x2=Double.parseDouble(dlist[1]);
            data.x3=Double.parseDouble(dlist[2]);
            data.type=Integer.parseInt(dlist[3]);
            dset.add(data);
            
            if(data.x1>max1){
                max1=data.x1;
            }
            if(data.x2>max2){
                max2=data.x2;
            }
            if(data.x3>max3){
                max3=data.x3;
            }
        }
        dset=normalization(dset,max1,max2,max3);
        this.dt=dset;
    }
    
    public List<Data> normalization(List<Data> dset,double m1,double m2,double m3){
        for(int i=0;i<dset.size();i++){
            dset.get(i).x1/=m1;
            dset.get(i).x2/=m2;
            dset.get(i).x3/=m3;
        }
        return dset;
    }
}

 

DataSet.java

package XBWKNN;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;




/**
 * 訓練數據
 * @author XBW
 * @date 2014年8月16日
 */

public class DataSet{
    String defaultpath="D:\\MachineLearning\\十大算法\\KNN\\knncode\\training.txt";
    List<Data> ds;
    
    @SuppressWarnings("null")
    public DataSet() throws IOException{
        List<Data> dset =new ArrayList<Data>();
        File ds=new File(defaultpath);
        @SuppressWarnings("resource")
        BufferedReader br = new BufferedReader(new FileReader(ds));
        String tsing;
        double max1=-1;
        double max2=-1;
        double max3=-1;
        while((tsing=br.readLine())!=null){
            String[] dlist=tsing.split("    ");
            Data data=new Data();
            data.x1=Double.parseDouble(dlist[0]);
            data.x2=Double.parseDouble(dlist[1]);
            data.x3=Double.parseDouble(dlist[2]);
            data.type=Integer.parseInt(dlist[3]);
            dset.add(data);
            
            if(data.x1>max1){
                max1=data.x1;
            }
            if(data.x2>max2){
                max2=data.x2;
            }
            if(data.x3>max3){
                max3=data.x3;
            }
        }
        dset=normalization(dset,max1,max2,max3);
        this.ds=dset;
    }
    
    public List<Data> normalization(List<Data> dset,double m1,double m2,double m3){
        for(int i=0;i<dset.size();i++){
            dset.get(i).x1/=m1;
            dset.get(i).x2/=m2;
            dset.get(i).x3/=m3;
        }
        return dset;
    }
}

 

Data.java

package XBWKNN;

/**
 * 一條數據
 * @author XBW
 * @date 2014年8月16日
 */

public class Data{
    Double x1;
    Double x2;
    Double x3;
    Double costfun;
    int type;
}

 

output:

image


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM