AdaBoost的java實現


目前學了幾個ML的分類的經典算法,但是一直想着是否有一種能將這些算法集成起來的,今天看到了AdaBoost,也算是半個集成,感覺這個思路挺好,很像人的訓練過程,並且對決策樹是一個很好的補充,因為決策樹容易過擬合,用AdaBoost可以讓一棵很深的決策樹將其分開成多棵矮樹,后來發現原來這個想法和random forest比較相似,RF的代碼等下周有空的時候可以寫一下。

這個貌似挺厲害的,看那些專門搞學術的人說是一篇很牛逼的論文證明說可以把弱學習提升到強學習。我這種搞工程的,能知道他的原理,適用范圍,能自己寫一遍代碼,感覺還是比那些讀幾遍論文只能惶惶其談的要安心些。

關於AdaBoost的基本概念,通過《機器學習方法》來概要的說下。

bagging和boosting的區別
bagging:是指在原始數據上通過放回抽樣,抽出和原始數據大小相等的新數據集(這個性質說明新數據集存在重復的值,而原始數據部分數據值不會出現在新數據集中),並重復該過程選擇N個新數據集,這樣通過N個分類器對這個N個數據集進行分類,最后選擇分類器投票結果中最多類別作為最后的分類結果。
boosting:相比bagging,boosting像是一種串行,bagging是一種並行的,bagging可以對於N個數據集通過N個分類器同時進行分類,並且每個分類器的權重是一樣的,但是boosting則相反,boosting是利用一個數據集依次由每個分類器進行分類,而確定每個分類器的權重是加大正確率高的分類器的權重,減少正確率低的分類器的權重。同時為了提高准確率,每次會降低被正確分類的樣本的權重,提高沒有正確分類的樣本的權重。這樣做其實比較符合人的決策過程,就是要多訓練自己容易做錯的題型,並且要多聽取正確性高的老師的意見。
 
那么AdaBoost的主要的兩個過程就是提高錯誤分類的樣本權重和提高正確率高的分類器的權重。
算法的步驟:
輸入:訓練集T,弱學習分類器(這里是一個節點的決策樹)
輸出:最終的分類器G
1 先初始化樣本權重值,D1={W11...W1n}W1i=1/n
2 根據樣本權重D1以及決策樹求分類誤差率,並求的最小的誤差率em,以及該決策樹
  em=
3 計算該分類器的權重
   可以看出,誤差率越小的,其權重越大
4 更新各個樣本的權重,Dm+1,(用公式編輯器好麻煩。。。 )
  
其中Zm是規范化銀子:
  
5 構建基本分類器
  F(X)=
6 計算該分類器下的誤差率,如果小於某個閾值就停止,否則從第二步開始迭代
 
終於不用打公式了。。。。
附上代碼:
  1 import java.io.BufferedReader;
  2 import java.io.FileInputStream;
  3 import java.io.IOException;
  4 import java.io.InputStreamReader;
  5 import java.util.ArrayList;
  6 
  7 class Stump{
  8     public int dim;
  9     public double thresh;
 10     public String condition;
 11     public double error;
 12     public ArrayList<Integer> labelList;
 13     double factor;
 14     
 15     public String toString(){
 16         return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList;
 17     }
 18 }
 19 
 20 class Utils{
 21     //加載數據集
 22     public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{
 23         ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
 24         FileInputStream fis=new FileInputStream(filename);
 25         InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
 26         BufferedReader br=new BufferedReader(isr);
 27         String line="";
 28         
 29         while((line=br.readLine())!=null){
 30             ArrayList<Double> data=new ArrayList<Double>();
 31             String[] s=line.split(" ");
 32             
 33             for(int i=0;i<s.length-1;i++){
 34                 data.add(Double.parseDouble(s[i]));
 35             }
 36             dataSet.add(data);
 37         }
 38         return  dataSet;
 39     }
 40     
 41     //加載類別
 42     public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{
 43         ArrayList<Integer> labelSet=new ArrayList<Integer>();
 44         
 45         FileInputStream fis=new FileInputStream(filename);
 46         InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
 47         BufferedReader br=new BufferedReader(isr);
 48         String line="";
 49         
 50         while((line=br.readLine())!=null){
 51             String[] s=line.split(" ");
 52             labelSet.add(Integer.parseInt(s[s.length-1]));
 53         }
 54         return labelSet;
 55     }
 56     //測試用的
 57     public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){
 58         for(ArrayList<Double> data:dataSet){
 59             System.out.println(data);
 60         }
 61     }
 62     //獲取最大值,用於求步長
 63     public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){
 64         double max=-9999.0;
 65         for(ArrayList<Double> data:dataSet){
 66             if(data.get(index)>max){
 67                 max=data.get(index);
 68             }
 69         }
 70         return max;
 71     }
 72     //獲取最小值,用於求步長
 73     public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){
 74         double min=9999.0;
 75         for(ArrayList<Double> data:dataSet){
 76             if(data.get(index)<min){
 77                 min=data.get(index);
 78             }
 79         }
 80         return min;
 81     }
 82     
 83     //獲取數據集中以該feature為特征,以thresh和conditions為value的葉子節點的決策樹進行划分后得到的預測類別
 84     public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){
 85         ArrayList<Integer> labelList=new ArrayList<Integer>();
 86         if(condition.compareTo("lt")==0){
 87             for(ArrayList<Double> data:dataSet){
 88                 if(data.get(feature)<=thresh){
 89                     labelList.add(1);
 90                 }else{
 91                     labelList.add(-1);
 92                 }
 93             }
 94         }else{
 95             for(ArrayList<Double> data:dataSet){
 96                 if(data.get(feature)>=thresh){
 97                     labelList.add(1);
 98                 }else{
 99                     labelList.add(-1);
100                 }
101             }
102         }
103         return labelList;
104     }
105     //求預測類別與真實類別的加權誤差
106     public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){
107         double error=0;
108         
109         int n=real.size();
110 
111         for(int i=0;i<fake.size();i++){
112             if(fake.get(i)!=real.get(i)){
113                 error+=weights.get(i);
114 
115             }
116         }
117         
118         return error;
119     }
120     //構造一棵單節點的決策樹,用一個Stump類來存儲這些基本信息。
121     public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){
122         int featureNum=dataSet.get(0).size();
123         
124         int rowNum=dataSet.size();
125         Stump stump=new Stump();
126         double minError=999.0;
127         System.out.println("第"+n+"次迭代");
128         for(int i=0;i<featureNum;i++){
129             double min=getMin(dataSet,i);
130             double max=getMax(dataSet,i);    
131             double step=(max-min)/(rowNum);
132             for(double j=min-step;j<=max+step;j=j+step){
133                 String[] conditions={"lt","gt"};//如果是lt,表示如果小於閥值則為真類,如果是gt,表示如果大於閥值則為正類
134                 for(String condition:conditions){
135                     ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition);
136                     
137                     double error=Utils.getError(labelList,labelSet,weights);
138                     if(error<minError){
139                         minError=error;
140                         stump.dim=i;
141                         stump.thresh=j;
142                         stump.condition=condition;
143                         stump.error=minError;
144                         stump.labelList=labelList;
145                         stump.factor=0.5*(Math.log((1-error)/error));
146                     }
147                     
148                 }
149             }
150             
151         }
152         
153         return stump;
154     }
155     
156     public static ArrayList<Double> getInitWeights(int n){
157         double weight=1.0/n;
158         ArrayList<Double> weights=new ArrayList<Double>();
159         for(int i=0;i<n;i++){
160             weights.add(weight);
161         }
162         return weights;
163     }
164     //更新樣本權值
165     public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){
166         double Z=0;
167         ArrayList<Double> newWeights=new ArrayList<Double>();
168         int row=labelList.size();
169         double e=Math.E;
170         double factor=stump.factor;
171         for(int i=0;i<row;i++){
172             Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i));
173         }
174         
175         
176         for(int i=0;i<row;i++){
177             double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z;
178             newWeights.add(weight);
179         }
180         return newWeights;
181     }
182     //對加權誤差累加
183     public static ArrayList<Double> InitAccWeightError(int n){
184         ArrayList<Double> accError=new ArrayList<Double>();
185         for(int i=0;i<n;i++){
186             accError.add(0.0);
187         }
188         return accError;
189     }
190     
191     public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){
192         ArrayList<Integer> t=stump.labelList;
193         double factor=stump.factor;
194         ArrayList<Double> newAccError=new ArrayList<Double>();
195         for(int i=0;i<t.size();i++){
196             double a=accerror.get(i)+factor*t.get(i);
197             newAccError.add(a);
198         }
199         return newAccError;
200     }
201     
202     public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){
203         ArrayList<Integer> a=new ArrayList<Integer>();
204         int wrong=0;
205         for(int i=0;i<accError.size();i++){
206             if(accError.get(i)>0){
207                 if(labelList.get(i)==-1){
208                     wrong++;
209                 }
210             }else if(labelList.get(i)==1){
211                 wrong++;
212             }
213         }
214         double error=wrong*1.0/accError.size();
215         return error;
216     }
217     
218     public static void showStumpList(ArrayList<Stump> G){
219         for(Stump s:G){
220             System.out.println(s);
221             System.out.println(" ");
222         }
223     }
224 }
225 
226 
227 public class Adaboost {
228 
229     /**
230      * @param args
231      * @throws IOException 
232      */
233     
234     public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){
235         int row=labelList.size();
236         ArrayList<Double> weights=Utils.getInitWeights(row);
237         ArrayList<Stump> G=new ArrayList<Stump>();
238         ArrayList<Double> accError=Utils.InitAccWeightError(row);
239         int n=1;
240         while(true){
241             Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵誤差率最小的單節點決策樹
242             G.add(stump);
243             weights=Utils.updateWeights(stump,labelList,weights);//更新權值
244             accError=Utils.accWeightError(accError,stump);//將加權誤差累加,因為這樣不用再利用分類器再求了
245             double error=Utils.calErrorRate(accError,labelList);
246             if(error<0.001){
247                 break;
248             }
249             n++;
250         }
251         return G;
252     }
253     
254     public static void main(String[] args) throws IOException {
255         // TODO Auto-generated method stub
256         String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt";
257         ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file);
258         ArrayList<Integer> labelSet=Utils.loadLabelSet(file);
259         ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet);
260         Utils.showStumpList(G);
261         System.out.println("finished");
262     }
263 
264 }

這里的數據采用的是統計學習方法中的數據

0 1
1 1
2 1
3 -1
4 -1
5 -1
6 1
7 1
8 1
9 -1

這里是單個特征的,也可以是多維數據,例如

1.0 2.1 1
2.0 1.1 1
1.3 1.0 -1
1.0 1.0 -1
2.0 1.0 1

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM