決策樹的分類過程和人的決策過程比較相似,就是先挑“權重”最大的那個考慮,然后再往下細分。比如你去看醫生,症狀是流鼻涕,咳嗽等,那么醫生就會根據你的流鼻涕這個權重最大的症狀先認為你是感冒,接着再根據你咳嗽等症狀細分你是否為病毒性感冒等等。決策樹的過程其實也是基於極大似然估計。那么我們用一個什么標准來衡量某個特征是權重最大的呢,這里有信息增益和基尼系數兩個。ID3算法采用的是信息增益這個量。
根據《統計學習方法》中的描述,G(D,A)表示數據集D在特征A的划分下的信息增益。具體公式:
G(D,A)=H(D)-H(D|A)。其中H(D)表示數據集D的熵,熵可以用來描述其混亂度,計算公式為
H(D)=可見對於數據集D而言,|Dk|表示類別為k的數量,其類別越多,越混亂。
而H(D|A)表示數據集D在A的划分下的的不確定性。他們的差也即是信息增益,表示由於特征A使得數據集D的分類的不確定減少的差,所以這個值越大說明A的分類對D越有效,也就是權重越大。
H(D|A)=|Dik|表示在特征A中value為i的划分下數據集類別為k的數量。
有了這兩個公式,接下來就可以寫代碼了。這里為了清晰的表示這個結果,采用了xml來輸出。由於剛開始學java所以只能即學即用(java和python簡直不能比,python寫ID3一百行代碼妥妥的搞定,java用了將近300行。。。)
算法步驟:
輸入:數據集D,特征集A(這里也可以輸入一個閥值,如果信息增益小於該閥值就直接作為葉節點,這樣可以避免過擬合)
輸出:xml文件
1 如果D中的類別是同一類,則作為葉節點,標記為該類Ck
2 如果特征集A中沒有特征了,那么作為葉節點,並且用數據集D中類別最多的類作為類標記
3 對D的各個特征求最大信息增益,選擇信息增益最大的特征Ag
4 對特征Ag中各個值ai繼續對數據集進行分割為Di
5 以Di為數據集,A-Ag為特征集為輸入進行1-4步驟
具體代碼:
1 import java.io.BufferedReader; 2 import java.io.FileInputStream; 3 import java.io.FileWriter; 4 import java.io.IOException; 5 import java.io.InputStreamReader; 6 import java.util.ArrayList; 7 import java.util.HashMap; 8 import java.util.HashSet; 9 import java.util.Map; 10 import java.util.Set; 11 12 import org.dom4j.Document; 13 import org.dom4j.DocumentHelper; 14 import org.dom4j.Element; 15 import org.dom4j.io.XMLWriter; 16 17 18 19 20 21 22 class Utils{ 23 //用於從文件中獲取數據集 24 public static ArrayList<ArrayList<String>> loadDataSet(String file) throws IOException{ 25 ArrayList<ArrayList<String>> dataSet=new ArrayList<ArrayList<String>>(); 26 FileInputStream fis=new FileInputStream(file); 27 InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); 28 BufferedReader br=new BufferedReader(isr); 29 String line=""; 30 line=br.readLine(); 31 while((line=br.readLine())!=null){ 32 String[] words=line.split(","); 33 ArrayList<String> data=new ArrayList<String>(); 34 for(int i=0;i<words.length;i++){ 35 data.add(words[i]); 36 } 37 dataSet.add(data); 38 } 39 br.close(); 40 isr.close(); 41 fis.close(); 42 return dataSet; 43 } 44 //用於從文件中獲取特征 45 public static ArrayList<String> loadFeature(String file) throws IOException{ 46 FileInputStream fis=new FileInputStream(file); 47 InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); 48 BufferedReader br=new BufferedReader(isr); 49 50 String[] line=br.readLine().split(","); 51 ArrayList<String> feature=new ArrayList<String>(); 52 for(int i=0;i<line.length-1;i++){ 53 feature.add(line[i]); 54 } 55 br.close(); 56 isr.close(); 57 fis.close(); 58 return feature; 59 } 60 //用於獲得數據集中的類別列表 61 public static ArrayList<String> getClassList(ArrayList<ArrayList<String>> dataSet){ 62 ArrayList<String> classList=new ArrayList<String>(); 63 int length=dataSet.get(0).size(); 64 for(ArrayList<String> data:dataSet){ 65 String label=data.get(length-1); 66 classList.add(label); 67 } 68 return classList; 69 } 70 //返回數據集中的特征數 71 public static int featureNum(ArrayList<ArrayList<String>> dataList){ 72 int len=dataList.get(0).size()-1; 73 return len; 74 } 75 76 77 // public static void writeToXML(String fileName) throws IOException{ 78 // Document document = DocumentHelper.createDocument(); 79 // Element root = document.addElement("DecisionTree"); 80 // Element outlook=root.addElement("outlook"); 81 // outlook.addAttribute("value","sunny"); 82 // Element humidity1=outlook.addElement("humidity"); 83 // humidity1.addAttribute("value","high"); 84 // humidity1.addText("no"); 85 // Element humidity2=outlook.addElement("humidity"); 86 // humidity2.addAttribute("value","normal"); 87 // humidity2.addText("yes"); 88 // 89 // XMLWriter writer=new XMLWriter(new FileWriter(fileName)); 90 // writer.write(document); 91 // writer.close(); 92 // } 93 //用於獲得數據集中第index列的map映射,方便后續的遍歷value和計算熵 94 public static Map<String,Integer> getSubMap(ArrayList<ArrayList<String>> dataSet,int index){ 95 int total=dataSet.size(); 96 Map<String,Integer> subMap=new HashMap(); 97 for(ArrayList<String> data:dataSet){ 98 String lable=data.get(index); 99 if(subMap.get(lable)==null){ 100 subMap.put(lable,1); 101 }else{ 102 subMap.put(lable,subMap.get(lable)+1); 103 } 104 } 105 return subMap; 106 } 107 //打印map,用於debug的時候 108 public static void showMap(Map<String,Integer> map){ 109 for(Map.Entry<String,Integer> entry:map.entrySet()){ 110 System.out.println(entry.getKey()+":"+entry.getValue()); 111 } 112 } 113 //求熵 114 public static double getEntropy(ArrayList<ArrayList<String>> dataSet,int index){ 115 int total=dataSet.size(); 116 Map<String,Integer> subMap=getSubMap(dataSet,index); 117 double entropy=0; 118 for(Map.Entry<String,Integer> entry:subMap.entrySet()){ 119 double temp=entry.getValue()*1.0/total; 120 entropy+=temp*(Math.log(temp)/Math.log(2)); 121 } 122 return -entropy; 123 } 124 //求信息增益最大的分割點 125 public static String bestFeatureSplit(ArrayList<ArrayList<String>> dataSet,ArrayList<String> featureList){ 126 int length=dataSet.get(0).size(); 127 double totalEntropy=getEntropy(dataSet,length-1); 128 129 130 131 int featureNum=dataSet.get(0).size()-1; 132 int index=-1; 133 double maxInfoGain=-1; 134 for(int i=0;i<featureNum;i++){ 135 double entropy=getEntropy(dataSet,i); 136 Map<String,Integer> map=getSubMap(dataSet,i);//獲得該特征下的map 137 ArrayList<String> lableList=new ArrayList<String>(); 138 double entropySum=0; 139 140 for(Map.Entry<String,Integer> entry:map.entrySet()){//這里的Di就是map中的特征的value值 141 Map<String,Integer> subMap=new HashMap(); 142 143 144 for(ArrayList<String> data:dataSet){ 145 if(data.get(i).compareTo(entry.getKey())==0){ 146 if(subMap.get(data.get(length-1))==null){ 147 148 subMap.put(data.get(length-1),1); 149 }else{ 150 subMap.put(data.get(length-1),subMap.get(data.get(length-1))+1); 151 } 152 } 153 } 154 double x=0; 155 for(Map.Entry<String,Integer> subEntry:subMap.entrySet()){ 156 double temp=subEntry.getValue()*1.0/entry.getValue(); 157 x+=temp*(Math.log(temp)/Math.log(2)); 158 } 159 160 entropySum+=-x*(entry.getValue())/dataSet.size(); 161 } 162 entropySum=totalEntropy-entropySum; 163 if(entropySum>maxInfoGain){ 164 index=i; 165 maxInfoGain=entropySum; 166 } 167 } 168 return featureList.get(index); 169 } 170 //分割數據集,index為特征的下標 171 public static ArrayList<ArrayList<String>> splitDataSet(ArrayList<ArrayList<String>> dataSet,int index,String value){ 172 ArrayList<ArrayList<String>> subDataSet=new ArrayList<ArrayList<String>>(); 173 for(ArrayList<String> data:dataSet){ 174 if(data.get(index).compareTo(value)==0){ 175 ArrayList<String> temp=new ArrayList<String>(); 176 for(int i=0;i<data.size();i++){ 177 if(i!=index){ 178 temp.add(data.get(i)); 179 } 180 } 181 subDataSet.add(temp); 182 } 183 } 184 return subDataSet; 185 } 186 //list-》map 187 public static Map<String,Integer> arrayToMap(ArrayList<String> list){ 188 Map<String,Integer> map=new HashMap(); 189 for(String word:list){ 190 if(map.get(word)==null){ 191 map.put(word,1); 192 }else{ 193 map.put(word,map.get(word)+1); 194 } 195 } 196 return map; 197 } 198 //求label中某個數量最多的類別 199 public static String major(ArrayList<String> labelList){ 200 Map<String,Integer> map=arrayToMap(labelList); 201 int max=0; 202 String label=""; 203 for(Map.Entry<String,Integer> entry:map.entrySet()){ 204 if(entry.getValue()>max){ 205 label=entry.getKey(); 206 } 207 } 208 return label; 209 } 210 211 public static Set<String> getValueFromDataSet(ArrayList<ArrayList<String>> dataSet,int index){ 212 ArrayList<String> values=new ArrayList<String>(); 213 for(ArrayList<String> data:dataSet){ 214 try{ 215 values.add(data.get(index)); 216 }catch(Exception e){ 217 218 System.out.println("index is "+index); 219 } 220 } 221 Set<String> set=new HashSet(); 222 for(String value:values){ 223 set.add(value); 224 } 225 return set; 226 } 227 228 public static ArrayList<String> copyArrayList(ArrayList<String> src){ 229 ArrayList<String> dest=new ArrayList<String>(); 230 for(String s:src){ 231 dest.add(s); 232 } 233 return dest; 234 } 235 236 237 public static void showArrayList(ArrayList<ArrayList<String>> dataSet){ 238 for(ArrayList<String> data:dataSet){ 239 System.out.println(data); 240 } 241 } 242 243 } 244 245 246 public class DecisionTree { 247 248 249 public static int createTree(ArrayList<ArrayList<String>> dataSet,ArrayList<String> featureList,Element e){ 250 ArrayList<String> labelList=Utils.getClassList(dataSet);//獲取數據集中label的列表 251 if(Utils.arrayToMap(labelList).size()==1){//表示label中只有一種類別,所以此時不需要再分類了 252 e.addText(labelList.get(0)); 253 return 1; 254 } 255 if(dataSet.get(0).size()==1){//表示此時已經沒有特征了,所以也不需要再繼續了,此時以label中最多的類別為該節點的類別 256 e.addText(Utils.major(labelList)); 257 return 1; 258 } 259 260 ArrayList<String> subFeatureList=Utils.copyArrayList(featureList); 261 262 263 264 String feature=Utils.bestFeatureSplit(dataSet,featureList); 265 subFeatureList.remove(feature); 266 int index=featureList.indexOf(feature); 267 268 Set<String> valueSet=Utils.getValueFromDataSet(dataSet,index); 269 // Element next=e.addElement(feature);//原來的代碼位置 270 for(String value:valueSet){ 271 Element next=e.addElement(feature);//后來放到這里之后,xml的輸出就正確了,原因在於每遞歸一次就需要創建一個element,所以應該在for內創建。 272 next.addAttribute("value",value); 273 ArrayList<ArrayList<String>> subDataSet=Utils.splitDataSet(dataSet,index,value); 274 createTree(subDataSet,subFeatureList,next); 275 } 276 return 1; 277 } 278 279 public static void main(String[] args) throws IOException { 280 // TODO Auto-generated method stub 281 String file="C:/Users/Administrator/Desktop/upload/DT.txt"; 282 String xml="C:/Users/Administrator/Desktop/upload/DT1.xml"; 283 ArrayList<ArrayList<String>> dataSet=Utils.loadDataSet(file); 284 ArrayList<String> featureList=Utils.loadFeature(file); 285 Document document = DocumentHelper.createDocument(); 286 Element root = document.addElement("DecisionTree"); 287 createTree(dataSet,featureList,root); 288 XMLWriter writer=new XMLWriter(new FileWriter(xml)); 289 writer.write(document); 290 writer.close(); 291 System.out.println("finished"); 292 } 293 294 }
這次除了算法上的理解更加深刻了外,在java上也學到了些關於xml解析,讀寫等方法。
另外對遞歸的使用也更加形象些,對於遞歸一個容易錯的點就是函數上的參數,一定要認真對待,要清楚該參數該在什么時候初始化,什么時候被用到。我一開始在第269行上就出現錯誤了,一開始沒有考慮清楚這個next該在什么時候分配,后來發現每次創建節點的時候我們在xml就要創建一個相應的節點用來描述他,所以應該是在for循環里面創建,如果在for外面創建就表示,該特征下的所有值都只有一個element。
當然對於set,map的遍歷啥的也更加清晰了。