目前學了幾個ML的分類的經典算法,但是一直想着是否有一種能將這些算法集成起來的,今天看到了AdaBoost,也算是半個集成,感覺這個思路挺好,很像人的訓練過程,並且對決策樹是一個很好的補充,因為決策樹容易過擬合,用AdaBoost可以讓一棵很深的決策樹將其分開成多棵矮樹,后來發現原來這個想法和random forest比較相似,RF的代碼等下周有空的時候可以寫一下。
這個貌似挺厲害的,看那些專門搞學術的人說是一篇很牛逼的論文證明說可以把弱學習提升到強學習。我這種搞工程的,能知道他的原理,適用范圍,能自己寫一遍代碼,感覺還是比那些讀幾遍論文只能惶惶其談的要安心些。
關於AdaBoost的基本概念,通過《機器學習方法》來概要的說下。
bagging和boosting的區別
可以看出,誤差率越小的,其權重越大
bagging:是指在原始數據上通過放回抽樣,抽出和原始數據大小相等的新數據集(這個性質說明新數據集存在重復的值,而原始數據部分數據值不會出現在新數據集中),並重復該過程選擇N個新數據集,這樣通過N個分類器對這個N個數據集進行分類,最后選擇分類器投票結果中最多類別作為最后的分類結果。
boosting:相比bagging,boosting像是一種串行,bagging是一種並行的,bagging可以對於N個數據集通過N個分類器同時進行分類,並且每個分類器的權重是一樣的,但是boosting則相反,boosting是利用一個數據集依次由每個分類器進行分類,而確定每個分類器的權重是加大正確率高的分類器的權重,減少正確率低的分類器的權重。同時為了提高准確率,每次會降低被正確分類的樣本的權重,提高沒有正確分類的樣本的權重。這樣做其實比較符合人的決策過程,就是要多訓練自己容易做錯的題型,並且要多聽取正確性高的老師的意見。
那么AdaBoost的主要的兩個過程就是提高錯誤分類的樣本權重和提高正確率高的分類器的權重。
算法的步驟:
輸入:訓練集T,弱學習分類器(這里是一個節點的決策樹)
輸出:最終的分類器G
1 先初始化樣本權重值,D1={W11...W1n}W1i=1/n
2 根據樣本權重D1以及決策樹求分類誤差率,並求的最小的誤差率em,以及該決策樹
em=

3 計算該分類器的權重

4 更新各個樣本的權重,Dm+1,(用公式編輯器好麻煩。。。 )

其中Zm是規范化銀子:

5 構建基本分類器
F(X)=

6 計算該分類器下的誤差率,如果小於某個閾值就停止,否則從第二步開始迭代
終於不用打公式了。。。。
附上代碼:
1 import java.io.BufferedReader; 2 import java.io.FileInputStream; 3 import java.io.IOException; 4 import java.io.InputStreamReader; 5 import java.util.ArrayList; 6 7 class Stump{ 8 public int dim; 9 public double thresh; 10 public String condition; 11 public double error; 12 public ArrayList<Integer> labelList; 13 double factor; 14 15 public String toString(){ 16 return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList; 17 } 18 } 19 20 class Utils{ 21 //加載數據集 22 public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{ 23 ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>(); 24 FileInputStream fis=new FileInputStream(filename); 25 InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); 26 BufferedReader br=new BufferedReader(isr); 27 String line=""; 28 29 while((line=br.readLine())!=null){ 30 ArrayList<Double> data=new ArrayList<Double>(); 31 String[] s=line.split(" "); 32 33 for(int i=0;i<s.length-1;i++){ 34 data.add(Double.parseDouble(s[i])); 35 } 36 dataSet.add(data); 37 } 38 return dataSet; 39 } 40 41 //加載類別 42 public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{ 43 ArrayList<Integer> labelSet=new ArrayList<Integer>(); 44 45 FileInputStream fis=new FileInputStream(filename); 46 InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); 47 BufferedReader br=new BufferedReader(isr); 48 String line=""; 49 50 while((line=br.readLine())!=null){ 51 String[] s=line.split(" "); 52 labelSet.add(Integer.parseInt(s[s.length-1])); 53 } 54 return labelSet; 55 } 56 //測試用的 57 public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){ 58 for(ArrayList<Double> data:dataSet){ 59 System.out.println(data); 60 } 61 } 62 //獲取最大值,用於求步長 63 public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){ 64 double max=-9999.0; 65 for(ArrayList<Double> data:dataSet){ 66 if(data.get(index)>max){ 67 max=data.get(index); 68 } 69 } 70 return max; 71 } 72 //獲取最小值,用於求步長 73 public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){ 74 double min=9999.0; 75 for(ArrayList<Double> data:dataSet){ 76 if(data.get(index)<min){ 77 min=data.get(index); 78 } 79 } 80 return min; 81 } 82 83 //獲取數據集中以該feature為特征,以thresh和conditions為value的葉子節點的決策樹進行划分后得到的預測類別 84 public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){ 85 ArrayList<Integer> labelList=new ArrayList<Integer>(); 86 if(condition.compareTo("lt")==0){ 87 for(ArrayList<Double> data:dataSet){ 88 if(data.get(feature)<=thresh){ 89 labelList.add(1); 90 }else{ 91 labelList.add(-1); 92 } 93 } 94 }else{ 95 for(ArrayList<Double> data:dataSet){ 96 if(data.get(feature)>=thresh){ 97 labelList.add(1); 98 }else{ 99 labelList.add(-1); 100 } 101 } 102 } 103 return labelList; 104 } 105 //求預測類別與真實類別的加權誤差 106 public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){ 107 double error=0; 108 109 int n=real.size(); 110 111 for(int i=0;i<fake.size();i++){ 112 if(fake.get(i)!=real.get(i)){ 113 error+=weights.get(i); 114 115 } 116 } 117 118 return error; 119 } 120 //構造一棵單節點的決策樹,用一個Stump類來存儲這些基本信息。 121 public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){ 122 int featureNum=dataSet.get(0).size(); 123 124 int rowNum=dataSet.size(); 125 Stump stump=new Stump(); 126 double minError=999.0; 127 System.out.println("第"+n+"次迭代"); 128 for(int i=0;i<featureNum;i++){ 129 double min=getMin(dataSet,i); 130 double max=getMax(dataSet,i); 131 double step=(max-min)/(rowNum); 132 for(double j=min-step;j<=max+step;j=j+step){ 133 String[] conditions={"lt","gt"};//如果是lt,表示如果小於閥值則為真類,如果是gt,表示如果大於閥值則為正類 134 for(String condition:conditions){ 135 ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition); 136 137 double error=Utils.getError(labelList,labelSet,weights); 138 if(error<minError){ 139 minError=error; 140 stump.dim=i; 141 stump.thresh=j; 142 stump.condition=condition; 143 stump.error=minError; 144 stump.labelList=labelList; 145 stump.factor=0.5*(Math.log((1-error)/error)); 146 } 147 148 } 149 } 150 151 } 152 153 return stump; 154 } 155 156 public static ArrayList<Double> getInitWeights(int n){ 157 double weight=1.0/n; 158 ArrayList<Double> weights=new ArrayList<Double>(); 159 for(int i=0;i<n;i++){ 160 weights.add(weight); 161 } 162 return weights; 163 } 164 //更新樣本權值 165 public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){ 166 double Z=0; 167 ArrayList<Double> newWeights=new ArrayList<Double>(); 168 int row=labelList.size(); 169 double e=Math.E; 170 double factor=stump.factor; 171 for(int i=0;i<row;i++){ 172 Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i)); 173 } 174 175 176 for(int i=0;i<row;i++){ 177 double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z; 178 newWeights.add(weight); 179 } 180 return newWeights; 181 } 182 //對加權誤差累加 183 public static ArrayList<Double> InitAccWeightError(int n){ 184 ArrayList<Double> accError=new ArrayList<Double>(); 185 for(int i=0;i<n;i++){ 186 accError.add(0.0); 187 } 188 return accError; 189 } 190 191 public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){ 192 ArrayList<Integer> t=stump.labelList; 193 double factor=stump.factor; 194 ArrayList<Double> newAccError=new ArrayList<Double>(); 195 for(int i=0;i<t.size();i++){ 196 double a=accerror.get(i)+factor*t.get(i); 197 newAccError.add(a); 198 } 199 return newAccError; 200 } 201 202 public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){ 203 ArrayList<Integer> a=new ArrayList<Integer>(); 204 int wrong=0; 205 for(int i=0;i<accError.size();i++){ 206 if(accError.get(i)>0){ 207 if(labelList.get(i)==-1){ 208 wrong++; 209 } 210 }else if(labelList.get(i)==1){ 211 wrong++; 212 } 213 } 214 double error=wrong*1.0/accError.size(); 215 return error; 216 } 217 218 public static void showStumpList(ArrayList<Stump> G){ 219 for(Stump s:G){ 220 System.out.println(s); 221 System.out.println(" "); 222 } 223 } 224 } 225 226 227 public class Adaboost { 228 229 /** 230 * @param args 231 * @throws IOException 232 */ 233 234 public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){ 235 int row=labelList.size(); 236 ArrayList<Double> weights=Utils.getInitWeights(row); 237 ArrayList<Stump> G=new ArrayList<Stump>(); 238 ArrayList<Double> accError=Utils.InitAccWeightError(row); 239 int n=1; 240 while(true){ 241 Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵誤差率最小的單節點決策樹 242 G.add(stump); 243 weights=Utils.updateWeights(stump,labelList,weights);//更新權值 244 accError=Utils.accWeightError(accError,stump);//將加權誤差累加,因為這樣不用再利用分類器再求了 245 double error=Utils.calErrorRate(accError,labelList); 246 if(error<0.001){ 247 break; 248 } 249 n++; 250 } 251 return G; 252 } 253 254 public static void main(String[] args) throws IOException { 255 // TODO Auto-generated method stub 256 String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt"; 257 ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file); 258 ArrayList<Integer> labelSet=Utils.loadLabelSet(file); 259 ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet); 260 Utils.showStumpList(G); 261 System.out.println("finished"); 262 } 263 264 }
這里的數據采用的是統計學習方法中的數據
0 1 1 1 2 1 3 -1 4 -1 5 -1 6 1 7 1 8 1 9 -1
這里是單個特征的,也可以是多維數據,例如
1.0 2.1 1 2.0 1.1 1 1.3 1.0 -1 1.0 1.0 -1 2.0 1.0 1