KNN算法基本的思路是比較好理解的,今天根據它的特點寫了一個實例,我會把所有的數據和代碼都寫在下面供大家參考,不足之處,請指正。謝謝!
update:工程代碼全部在本頁面中,測試數據已丟失,建議去UCI Dataset中找一個自行測試一下。
幾點說明:
1.KNN中的K=5;
2.在計算權重時,采用的是減去函數{1,0.8,0.6,0.4,0.2},當然你也可以采用反函數或高斯函數;
3.5%作為測試集(decision.txt),95%作為訓練集(training.txt);
4.在計算costfun之前,對所有的屬性進行了歸一化,由於這里不知道數據集每個屬性代表的含義,所以就一視同仁,實際情況下,應該具體問題具體分析;
XBWKNN.java
package XBWKNN; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.List; /** * KNN算法 * @author XBW * @date 2014年8月16日 */ public class XBWKNN{ public final static int KofKNN=5; public final static double weight[]={1,0.9,0.7,0.4,0.1}; //減法函數y=1-0.2*x /** * knn * @param data * @param ds * @return ans */ public static int knn(Data data,DataSet ds){ int ans = 0; List<Data> dis=calcDis(data,ds); ans=calcKDis(data,dis); return ans; } /** * 計算訓練集中所有向量的距離,排序之后取前K個 * @param data * @param ds * @return */ @SuppressWarnings("null") public static List<Data>calcDis(Data data,DataSet ds){ List<Data> anslist =new ArrayList<Data>(); double dx1=data.x1; double dx2=data.x2; double dx3=data.x3; for(int i=0;i<ds.ds.size();i++){ double x1=ds.ds.get(i).x1; double x2=ds.ds.get(i).x2; double x3=ds.ds.get(i).x3; ds.ds.get(i).costfun=Math.sqrt((dx1-x1)*(dx1-x1)+(dx2-x2)*(dx2-x2)+(dx3-x3)*(dx3-x3)); anslist.add(ds.ds.get(i)); } Collections.sort(anslist,new Comparator<Data>(){ public int compare(Data o1, Data o2) { Double s=o1.costfun-o2.costfun; if(s<0) return -1; else return 1; } }); return anslist; } /** * 按一定的權重計算出前K個 * @param data * @param ds * @return */ public static int calcKDis(Data data,List<Data> anslist){ Double[] anstype={0.0,0.0,0.0,0.0}; for(int i=0;i<KofKNN;i++){ if(anslist.get(i).type==1){ anstype[1]+=weight[i]; } else if(anslist.get(i).type==2){ anstype[2]+=weight[i]; } if(anslist.get(i).type==3){ anstype[3]+=weight[i]; } } Double maxt=-1.0; int tag=1; for(int i=1;i<=3;i++){ if(maxt<anstype[i]){ tag=i; maxt=anstype[i]; } } return tag; } public static void main(String[] args) throws IOException{ DataSet ds=new DataSet(); DataTest dt=new DataTest(); int correct=0; for(int i=0;i<dt.dt.size();i++){ Data data=dt.dt.get(i); int result=knn(data,ds); if(result==data.type){ correct++; } } System.out.println("total test num :"+dt.dt.size()); System.out.println("correct test num :"+correct); System.out.println("ratio :"+correct/(double)dt.dt.size()); } }
Datatest.java
package XBWKNN; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.List; /** * 測試數據 * @author XBW * @date 2014年8月16日 */ public class DataTest{ String defaultpath="D:\\MachineLearning\\十大算法\\KNN\\knncode\\decision.txt"; List<Data> dt; @SuppressWarnings("null") public DataTest() throws IOException{ List<Data> dset = new ArrayList<Data>(); File ds=new File(defaultpath); @SuppressWarnings("resource") BufferedReader br = new BufferedReader(new FileReader(ds)); String tsing; double max1=-1; double max2=-1; double max3=-1; while((tsing=br.readLine())!=null){ String[] dlist=tsing.split(" "); Data data=new Data(); data.x1=Double.parseDouble(dlist[0]); data.x2=Double.parseDouble(dlist[1]); data.x3=Double.parseDouble(dlist[2]); data.type=Integer.parseInt(dlist[3]); dset.add(data); if(data.x1>max1){ max1=data.x1; } if(data.x2>max2){ max2=data.x2; } if(data.x3>max3){ max3=data.x3; } } dset=normalization(dset,max1,max2,max3); this.dt=dset; } public List<Data> normalization(List<Data> dset,double m1,double m2,double m3){ for(int i=0;i<dset.size();i++){ dset.get(i).x1/=m1; dset.get(i).x2/=m2; dset.get(i).x3/=m3; } return dset; } }
DataSet.java
package XBWKNN; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.List; /** * 訓練數據 * @author XBW * @date 2014年8月16日 */ public class DataSet{ String defaultpath="D:\\MachineLearning\\十大算法\\KNN\\knncode\\training.txt"; List<Data> ds; @SuppressWarnings("null") public DataSet() throws IOException{ List<Data> dset =new ArrayList<Data>(); File ds=new File(defaultpath); @SuppressWarnings("resource") BufferedReader br = new BufferedReader(new FileReader(ds)); String tsing; double max1=-1; double max2=-1; double max3=-1; while((tsing=br.readLine())!=null){ String[] dlist=tsing.split(" "); Data data=new Data(); data.x1=Double.parseDouble(dlist[0]); data.x2=Double.parseDouble(dlist[1]); data.x3=Double.parseDouble(dlist[2]); data.type=Integer.parseInt(dlist[3]); dset.add(data); if(data.x1>max1){ max1=data.x1; } if(data.x2>max2){ max2=data.x2; } if(data.x3>max3){ max3=data.x3; } } dset=normalization(dset,max1,max2,max3); this.ds=dset; } public List<Data> normalization(List<Data> dset,double m1,double m2,double m3){ for(int i=0;i<dset.size();i++){ dset.get(i).x1/=m1; dset.get(i).x2/=m2; dset.get(i).x3/=m3; } return dset; } }
Data.java
package XBWKNN; /** * 一條數據 * @author XBW * @date 2014年8月16日 */ public class Data{ Double x1; Double x2; Double x3; Double costfun; int type; }
output:


