1. A Brief Introduction to the ID3 Algorithm
The ID3 algorithm is generally credited to J. Ross Quinlan at the University of Sydney, who first presented it in 1975; it was built on the earlier Concept Learning System (CLS) algorithm. ID3 is a decision-tree algorithm. A decision tree consists of decision nodes, branches, and leaves. The topmost node of the tree is the root, and each branch leads either to a new decision node or to a leaf of the tree. Each decision node represents a question or decision, usually corresponding to an attribute of the objects being classified, and each leaf node represents a possible classification result. Traversing the decision tree from top to bottom, a test is met at every node, and the different outcomes of the test at each node lead down different branches until a leaf node is finally reached. This traversal is the process of classifying with a decision tree: a handful of attribute values, examined one after another, determine the class an object belongs to.
2. The Foundation of ID3: Information Theory
The ID3 algorithm is grounded in information theory, using information entropy and information gain as its measures to carry out the classification of the data. Some basic concepts from information theory follow:
Definition 1: If there are n messages of equal probability, then the probability p of each message is 1/n, and the amount of information conveyed by one message is -log2(1/n).
Definition 2: If there are n messages with probability distribution P = (p1, p2, ..., pn), then the amount of information conveyed by the distribution is called the entropy of P, written I(P) = -p1*log2(p1) - p2*log2(p2) - ... - pn*log2(pn).
Definition 3: If a record set T is partitioned by the values of its class attribute into mutually disjoint classes C1, C2, ..., Ck, then the amount of information needed to identify which class an element of T belongs to is Info(T) = I(P), where P is the probability distribution of C1, C2, ..., Ck, that is, P = (|C1|/|T|, |C2|/|T|, ..., |Ck|/|T|).
Definition 4: If we first partition T by the values of a non-class attribute X into subsets T1, T2, ..., Tn, then the amount of information needed to determine the class of an element of T is obtained as the weighted average of the Info(Ti): Info(X,T) = sum for i = 1 to n of (|Ti|/|T|) * Info(Ti).
Definition 5: The information gain is the difference between two amounts of information: the information needed to identify an element of T, and the information still needed to identify an element of T once the value of attribute X has been obtained. The formula is Gain(X,T) = Info(T) - Info(X,T).
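To make Definitions 2, 4 and 5 concrete, here is a minimal Python sketch (the helper names are my own, not part of any library) that evaluates them on the class counts that will appear in the worked example of section three:

from math import log2

def entropy(probs):
    # I(P) = -sum(p * log2(p)), treating 0*log2(0) as 0 (Definition 2)
    return -sum(p * log2(p) for p in probs if p > 0)

# Info(T) for 9 "yes" and 5 "no" records out of 14 (Definition 3)
info_T = entropy([9 / 14, 5 / 14])                        # about 0.940

# Info(X,T): weighted average over the subsets T1..Tn (Definition 4);
# these happen to be the class counts of the age subsets in section three
subsets = [[3, 2], [0, 4], [2, 3]]
total = sum(sum(s) for s in subsets)
info_X_T = sum(sum(s) / total * entropy([c / sum(s) for c in s])
               for s in subsets)                          # about 0.694

# Gain(X,T) = Info(T) - Info(X,T) (Definition 5)
print(round(info_T - info_X_T, 3))                        # 0.247 at full precision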
The ID3 algorithm computes the information gain of every attribute and selects the attribute with the highest gain as the test attribute for the given set. It creates a node for the selected attribute, labels the node with that attribute, creates one branch for each value of the attribute, and then repeats the information-gain computation recursively on each branch.
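In code, that selection step is simply an argmax over the remaining attributes. A sketch, assuming an info_gain(D, a) helper such as the one in the Python implementation of section four:

def choose_attribute(D, attributes, info_gain):
    # pick the attribute whose split yields the highest information gain
    return max(attributes, key=lambda a: info_gain(D, a))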
3. A Worked Example of the ID3 Algorithm
A worked example of the ID3 algorithm is given below:
RID | age | income | student | credit_rating | buy_computer |
1 | youth | high | no | fair | no |
2 | youth | high | no | excellent | no |
3 | middle_aged | high | no | fair | yes |
4 | senior | medium | no | fair | yes |
5 | senior | low | yes | fair | yes |
6 | senior | low | yes | excellent | no |
7 | middle_aged | low | yes | excellent | yes |
8 | youth | medium | no | fair | no |
9 | youth | low | yes | fair | yes |
10 | senior | medium | yes | fair | yes |
11 | youth | medium | yes | excellent | yes |
12 | middle_aged | medium | no | excellent | yes |
13 | middle_aged | high | yes | fair | yes |
14 | senior | medium | no | excellent | no |
The total number of records is 14. The candidate attributes are age (youth[5], middle_aged[4], senior[5]), income (high[4], medium[6], low[4]), student (no[7], yes[7]) and credit_rating (fair[8], excellent[6]). The target attribute is buy_computer (no[5], yes[9]). The desired result is a tree that predicts the value of buy_computer from age, income, student and credit_rating. Writing D for the initial dataset and A for the candidate attribute list, the computation steps are given below.
Step one: compute the information entropy of the target attribute over dataset D: Info(buy_computer) = -(5/14)*log2(5/14) - (9/14)*log2(9/14) = 0.94
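As a quick sanity check, this value can be reproduced in a couple of lines of Python:

from math import log2

info = -(5 / 14) * log2(5 / 14) - (9 / 14) * log2(9 / 14)
print(round(info, 2))  # 0.94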
Step two: for every attribute in the candidate attribute list A, compute the entropy of buy_computer given that the attribute's value is known, i.e. the conditional entropy.
For the age attribute: youth (no[3], yes[2]), middle_aged (no[0], yes[4]), senior (no[2], yes[3]). First compute the entropy of each of the youth, middle_aged and senior subsets:
Info_age(buy_computer|youth) = -(3/5)*log2(3/5) - (2/5)*log2(2/5) = 0.971
Info_age(buy_computer|middle_aged) = -(4/4)*log2(4/4) - (0/4)*log2(0/4) = 0 (taking 0*log2(0) = 0)
Info_age(buy_computer|senior) = -(2/5)*log2(2/5) - (3/5)*log2(3/5) = 0.971
Then Info_age(buy_computer) = (5/14)*0.971 + (4/14)*0 + (5/14)*0.971 = 0.694
Likewise: Info_income(buy_computer) = 0.911; Info_student(buy_computer) = 0.789; Info_credit_rating(buy_computer) = 0.892.
Step three: compute the information gain for each candidate attribute. The larger this value, the more of the target attribute's entropy the candidate attribute removes, and so the higher up the decision tree that attribute should be placed. The results are:
Gain(age, buy_computer) = Info(buy_computer) - Info_age(buy_computer) = 0.94 - 0.694 = 0.246
Gain(income, buy_computer) = Info(buy_computer) - Info_income(buy_computer) = 0.94 - 0.911 = 0.029
Gain(student, buy_computer) = Info(buy_computer) - Info_student(buy_computer) = 0.94 - 0.789 = 0.151
Gain(credit_rating, buy_computer) = Info(buy_computer) - Info_credit_rating(buy_computer) = 0.94 - 0.892 = 0.048
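All four conditional entropies and gains can be reproduced directly from the table. A sketch (the rows list and helper names are mine; the last digit can differ from the rounded arithmetic above, e.g. 0.247 instead of 0.246, because full precision is carried throughout):

from collections import Counter
from math import log2

# (age, income, student, credit_rating, buy_computer), copied from the table above
rows = [
    ("youth", "high", "no", "fair", "no"),
    ("youth", "high", "no", "excellent", "no"),
    ("middle_aged", "high", "no", "fair", "yes"),
    ("senior", "medium", "no", "fair", "yes"),
    ("senior", "low", "yes", "fair", "yes"),
    ("senior", "low", "yes", "excellent", "no"),
    ("middle_aged", "low", "yes", "excellent", "yes"),
    ("youth", "medium", "no", "fair", "no"),
    ("youth", "low", "yes", "fair", "yes"),
    ("senior", "medium", "yes", "fair", "yes"),
    ("youth", "medium", "yes", "excellent", "yes"),
    ("middle_aged", "medium", "no", "excellent", "yes"),
    ("middle_aged", "high", "yes", "fair", "yes"),
    ("senior", "medium", "no", "excellent", "no"),
]

def entropy(labels):
    n = len(labels)
    return -sum(c / n * log2(c / n) for c in Counter(labels).values())

info_T = entropy([r[-1] for r in rows])  # Info(buy_computer), about 0.940

for i, name in enumerate(["age", "income", "student", "credit_rating"]):
    cond = sum(len(sub) / len(rows) * entropy([r[-1] for r in sub])
               for v in {r[i] for r in rows}
               for sub in [[r for r in rows if r[i] == v]])
    print(name, round(info_T - cond, 3))
# age 0.247, income 0.029, student 0.152, credit_rating 0.048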
Step four: select the attribute with the largest information gain as the current node; here that is age. The initial dataset D is split by the values of age into the following cases.
1. When age is youth, the sub-dataset is D1:
RID | income | student | credit_rating | buy_computer |
1 | high | no | fair | no |
2 | high | no | excellent | no |
8 | medium | no | fair | no |
9 | low | yes | fair | yes |
11 | medium | yes | excellent | yes |
2. When age is middle_aged, the sub-dataset is D2:
RID | income | student | credit_rating | buy_computer |
3 | high | no | fair | yes |
7 | low | yes | excellent | yes |
12 | medium | no | excellent | yes |
13 | high | yes | fair | yes |
3. When age is senior, the sub-dataset is D3:
RID | income | student | credit_rating | buy_computer |
4 | medium | no | fair | yes |
5 | low | yes | fair | yes |
6 | low | yes | excellent | no |
10 | medium | yes | fair | yes |
14 | medium | no | excellent | no |
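The split itself is a plain group-by on the chosen attribute, which is exactly what the devide_set/devideDatas helpers in the implementations of section four do. A sketch, reusing the rows list from the gain sketch above:

subsets = {}
for r in rows:                              # rows as defined in the gain sketch
    subsets.setdefault(r[0], []).append(r)  # index 0 is the age column
for value, sub in subsets.items():
    print(value, len(sub))                  # youth 5, middle_aged 4, senior 5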
Step five: remove the attribute just used (age) from the candidate attribute list A, and for each sub-dataset Di produced in step four, repeat the procedure from step one with the reduced attribute list A. The recursion terminates when:
- Within a branch the target attribute takes only a single value, as happens here when age is middle_aged.
- Within a branch, one value of the target attribute reaches a user-controlled threshold of the total; for instance, with the threshold set at 90%, the recursion can stop as soon as one value of buy_computer accounts for at least 90% of the branch (see the sketch after this list).
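A sketch of such a stopping test (the should_stop helper is mine; 0.9 is the user-chosen threshold from the second condition):

from collections import Counter

def should_stop(labels, ratio=0.9):
    # stop when the branch is pure, or when the majority value of the
    # target attribute reaches the chosen proportion of the branch
    majority = Counter(labels).most_common(1)[0][1]
    return len(set(labels)) == 1 or majority / len(labels) >= ratio

print(should_stop(["yes"] * 4))           # True: pure, like age = middle_aged
print(should_stop(["yes"] * 9 + ["no"]))  # True: 90% majority
print(should_stop(["yes", "no", "yes"]))  # False: keep splitting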
After several rounds of this recursion the final tree structure is obtained (the original figure is not reproduced here), and the rules it encodes are:
IF AGE = middle_aged, THEN buy_computer = yes
IF AGE = youth AND STUDENT = yes, THEN buy_computer = yes
IF AGE = youth AND STUDENT = no, THEN buy_computer = no
IF AGE = senior AND CREDIT_RATING = excellent, THEN buy_computer = no
IF AGE = senior AND CREDIT_RATING = fair, THEN buy_computer = yes
So, if a new instance is ("15", "youth", "medium", "yes", "fair"), the predicted value of buy_computer is "yes".
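These five rules can be applied to new records directly; a hand-coded sketch (the predict helper is mine, not part of the programs below):

def predict(age, income, student, credit_rating):
    # the five IF/THEN rules above; income is never consulted
    if age == "middle_aged":
        return "yes"
    if age == "youth":
        return "yes" if student == "yes" else "no"
    return "no" if credit_rating == "excellent" else "yes"  # age == senior

print(predict("youth", "medium", "yes", "fair"))  # "yes", as for record 15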
4. ID3 Algorithm Implementations
Implementations of the ID3 algorithm in both Python and Java are given below.
Python program:

# -*- coding: utf-8 -*-


class Node:
    '''Represents a decision tree node.'''

    def __init__(self, parent=None, dataset=None):
        self.dataset = dataset  # training instances that fall on this node
        self.result = None      # result class label
        self.attr = None        # index of this node's splitting attribute
        self.childs = {}        # subtrees, as key-value pairs (value of attr, subtree)
        self.parent = parent    # parent of this node


def entropy(props):
    if not isinstance(props, (tuple, list)):
        return None

    from math import log
    log2 = lambda x: log(x) / log(2)  # an anonymous function
    e = 0.0
    for p in props:
        if p != 0:
            e = e - p * log2(p)
    return e


def info_gain(D, A, T=-1, return_ratio=False):
    '''Information gain g(D,A) of feature A on training set D.

    g(D,A) = entropy(D) - entropy(D|A)
    Assumes the last element of each tuple in D is the class label.
    T is the index of the target attribute; -1 means the last element.'''
    if not isinstance(D, (set, list)):
        return None
    if not type(A) is int:
        return None
    C = {}    # class counts
    DA = {}   # counts of the values of feature A
    CDA = {}  # counts of each (class, value of A) combination
    for t in D:
        C[t[T]] = C.get(t[T], 0) + 1
        DA[t[A]] = DA.get(t[A], 0) + 1
        CDA[(t[T], t[A])] = CDA.get((t[T], t[A]), 0) + 1

    PC = map(lambda x: 1.0 * x / len(D), C.values())  # class (target attribute) probabilities
    entropy_D = entropy(tuple(PC))  # in Python 3, map() returns an iterator, so convert it to a tuple

    PCDA = {}  # conditional probability of each class given each value of A
    for key, value in CDA.items():
        a = key[1]  # the value of feature A
        pca = 1.0 * value / DA[a]  # float division (plain / would truncate in Python 2)
        PCDA.setdefault(a, []).append(pca)

    condition_entropy = 0.0
    for a, v in DA.items():
        p = 1.0 * v / len(D)
        e = entropy(PCDA[a])
        condition_entropy += e * p

    if return_ratio:
        return (entropy_D - condition_entropy) / entropy_D
    else:
        return entropy_D - condition_entropy


def get_result(D, T=-1):
    '''Return the most frequent value of target feature T in dataset D.'''
    if not isinstance(D, (set, list)):
        return None
    if not type(T) is int:
        return None
    count = {}
    for t in D:
        count[t[T]] = count.get(t[T], 0) + 1
    max_count = 0
    result = None
    for key, value in count.items():
        if value > max_count:
            max_count = value
            result = key
    return result


def devide_set(D, A):
    '''Split dataset D into subsets according to the values of feature A.'''
    if not isinstance(D, (set, list)):
        return None
    if not type(A) is int:
        return None
    subset = {}
    for t in D:
        subset.setdefault(t[A], []).append(t)
    return subset


def build_tree(D, A, threshold=0.0001, T=-1, Tree=None, algo="ID3"):
    '''Build a decision tree from dataset D and feature set A.

    T is the index of the target attribute within each tuple.
    Both ID3 and C4.5 are currently supported.'''
    if Tree is not None and not isinstance(Tree, Node):
        return None
    if not isinstance(D, (set, list)):
        return None
    if not type(A) is set:
        return None

    if Tree is None:
        Tree = Node(None, D)
    subset = devide_set(D, T)
    if len(subset) <= 1:
        for key in subset.keys():
            Tree.result = key
        del subset
        return Tree
    if len(A) <= 0:
        Tree.result = get_result(D)
        return Tree
    use_gain_ratio = False if algo == "ID3" else True

    max_gain = 0
    attr_id = None
    for a in A:
        gain = info_gain(D, a, return_ratio=use_gain_ratio)
        if gain > max_gain:
            max_gain = gain
            attr_id = a  # the feature with the largest information gain
    if max_gain < threshold:
        Tree.result = get_result(D)
        return Tree
    Tree.attr = attr_id
    subD = devide_set(D, attr_id)
    del D[:]  # drop intermediate data to free memory
    Tree.dataset = None
    A.discard(attr_id)  # remove the used feature from the feature set
    for key in subD.keys():
        tree = Node(Tree, subD.get(key))
        Tree.childs[key] = tree
        build_tree(subD.get(key), A, threshold, T, tree)
    return Tree


def print_brance(brance, target):
    odd = 0
    for e in brance:
        print(e, '=' if odd == 0 else '∧', end=' ')
        odd = 1 - odd
    print("target =", target)


def print_tree(Tree, stack=[]):
    if Tree is None:
        return
    if Tree.result is not None:
        print_brance(stack, Tree.result)
        return
    stack.append(Tree.attr)
    for key, value in Tree.childs.items():
        stack.append(key)
        print_tree(value, stack)
        stack.pop()
    stack.pop()


def classify(Tree, instance):
    if Tree is None:
        return None
    if Tree.result is not None:
        return Tree.result
    if instance[Tree.attr] in Tree.childs:
        return classify(Tree.childs[instance[Tree.attr]], instance)
    else:
        return None


dataset = [
    ("青年", "否", "否", "一般", "否")
    , ("青年", "否", "否", "好", "否")
    , ("青年", "是", "否", "好", "是")
    , ("青年", "是", "是", "一般", "是")
    , ("青年", "否", "否", "一般", "否")
    , ("中年", "否", "否", "一般", "否")
    , ("中年", "否", "否", "好", "否")
    , ("老年", "是", "否", "非常好", "是")
    , ("老年", "否", "是", "一般", "是")
    , ("老年", "否", "是", "一般", "是")
    , ("老年", "否", "是", "一般", "是")
    , ("老年", "否", "是", "好", "是")
    , ("老年", "是", "否", "一般", "是")
    , ("老年", "是", "否", "一般", "否")
    , ("老年", "否", "否", "一般", "否")
]

s = set(range(0, len(dataset[0]) - 1))
s = set([0, 1, 3, 4])  # override: note that index 4 is the label column itself
T = build_tree(dataset, s)
print_tree(T)
print(classify(T, ("老年", "是", "否", "一般", "否")))
print(classify(T, ("老年", "是", "否", "一般", "是")))
print(classify(T, ("老年", "是", "是", "好", "否")))
print(classify(T, ("青年", "是", "否", "好", "是")))
print(classify(T, ("中年", "是", "否", "好", "否")))
Note that the training set of this Python program is not the example given above. Its output is:
0 = 青年 ∧ 1 = 否 ∧ target = 否
0 = 青年 ∧ 1 = 是 ∧ target = 是
0 = 中年 ∧ target = 否
0 = 老年 ∧ 3 = 好 ∧ target = 是
0 = 老年 ∧ 3 = 非常好 ∧ target = 是
0 = 老年 ∧ 3 = 一般 ∧ 4 = 否 ∧ target = 否
0 = 老年 ∧ 3 = 一般 ∧ 4 = 是 ∧ target = 是
否
是
是
是
否
Java program. This program does use the dataset of the example above; the code and results are as follows:

import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

public class ID3Tree {
    private List<String[]> datas;
    private List<Integer> attributes;
    private double threshold = 0.0001;
    private int targetIndex = 1;
    private Node tree;
    private Map<Integer, String> attributeMap;

    protected ID3Tree() {
        super();
    }

    public ID3Tree(List<String[]> datas, List<Integer> attributes, Map<Integer, String> attributeMap, int targetIndex) {
        this(datas, attributes, attributeMap, 0.0001, targetIndex, null);
    }

    public ID3Tree(List<String[]> datas, List<Integer> attributes, Map<Integer, String> attributeMap, double threshold, int targetIndex, Node tree) {
        super();
        this.datas = datas;
        this.attributes = attributes;
        this.attributeMap = attributeMap;
        this.threshold = threshold;
        this.targetIndex = targetIndex;
        this.tree = tree;
    }

    /**
     * Tree node.
     *
     * @author Gerry.Liu
     */
    class Node {
        private List<String[]> dataset; // training instances that fall on this node
        private String result;          // result class label
        private int attr;               // index of this node's splitting attribute
        private Node parent;            // parent of this node
        private Map<String, List<Node>> childs; // children of this node

        public Node(List<String[]> datas, Node parent) {
            this.dataset = datas;
            this.parent = parent;
            this.childs = new HashMap<>();
        }
    }

    class KeyValue {
        private String first;
        private String second;

        public KeyValue(String first, String second) {
            super();
            this.first = first;
            this.second = second;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + getOuterType().hashCode();
            result = prime * result + ((first == null) ? 0 : first.hashCode());
            result = prime * result + ((second == null) ? 0 : second.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (obj == null)
                return false;
            if (getClass() != obj.getClass())
                return false;
            KeyValue other = (KeyValue) obj;
            if (!getOuterType().equals(other.getOuterType()))
                return false;
            if (first == null) {
                if (other.first != null)
                    return false;
            } else if (!first.equals(other.first))
                return false;
            if (second == null) {
                if (other.second != null)
                    return false;
            } else if (!second.equals(other.second))
                return false;
            return true;
        }

        private ID3Tree getOuterType() {
            return ID3Tree.this;
        }
    }

    /**
     * Compute the entropy from a list of probabilities:<br/>
     * entropy(p1,p2,...,pn) = -p1*log2(p1) - p2*log2(p2) - ... - pn*log2(pn)
     */
    private double entropy(List<Double> props) {
        if (props == null || props.isEmpty()) {
            return 0;
        } else {
            double result = 0;
            for (double p : props) {
                if (p > 0) {
                    result = result - p * Math.log(p) / Math.log(2);
                }
            }
            return result;
        }
    }

    /**
     * Compute probabilities from counts.
     */
    private List<Double> calcProbability(int totalRecords, Collection<Integer> counts) {
        if (totalRecords == 0 || counts == null || counts.isEmpty()) {
            return null;
        }

        List<Double> result = new ArrayList<>();
        for (int count : counts) {
            result.add(1.0 * count / totalRecords);
        }
        return result;
    }

    /**
     * Information gain Gain(datas, attribute):<br/>
     * the gain of feature attribute (A) on training set datas (D),<br/>
     * g(D,A) = entropy(D) - entropy(D|A)
     */
    private double infoGain(List<String[]> datas, int attributeIndex, int targetAttributeIndex) {
        if (datas == null || datas.isEmpty()) {
            return 0;
        }

        Map<String, Integer> targetAttributeCountMap = new HashMap<String, Integer>(); // class (target attribute) counts
        Map<String, Integer> featureAttributesCountMap = new HashMap<>(); // counts of the feature attribute's values
        Map<KeyValue, Integer> tfAttributeCountMap = new HashMap<>(); // counts of each (feature value, class) combination

        for (String[] arrs : datas) {
            String tv = arrs[targetAttributeIndex];
            String fv = arrs[attributeIndex];
            if (targetAttributeCountMap.containsKey(tv)) {
                targetAttributeCountMap.put(tv, targetAttributeCountMap.get(tv) + 1);
            } else {
                targetAttributeCountMap.put(tv, 1);
            }
            if (featureAttributesCountMap.containsKey(fv)) {
                featureAttributesCountMap.put(fv, featureAttributesCountMap.get(fv) + 1);
            } else {
                featureAttributesCountMap.put(fv, 1);
            }
            KeyValue key = new KeyValue(fv, tv);
            if (tfAttributeCountMap.containsKey(key)) {
                tfAttributeCountMap.put(key, tfAttributeCountMap.get(key) + 1);
            } else {
                tfAttributeCountMap.put(key, 1);
            }
        }

        int totalDataSize = datas.size();
        // class probabilities
        List<Double> probabilitys = calcProbability(totalDataSize, targetAttributeCountMap.values());
        // entropy of the target attribute
        double entropyDatas = this.entropy(probabilitys);

        // conditional probabilities:
        // step one, the probability of each class given each feature value
        Map<String, List<Double>> pcda = new HashMap<>();
        for (Map.Entry<KeyValue, Integer> entry : tfAttributeCountMap.entrySet()) {
            String key = entry.getKey().first;
            double pca = 1.0 * entry.getValue() / featureAttributesCountMap.get(key);
            if (pcda.containsKey(key)) {
                pcda.get(key).add(pca);
            } else {
                List<Double> list = new ArrayList<Double>();
                list.add(pca);
                pcda.put(key, list);
            }
        }
        // step two, the entropy for each feature value, combined as a weighted average
        double conditionEntropy = 0.0;
        for (Map.Entry<String, Integer> entry : featureAttributesCountMap.entrySet()) {
            double p = 1.0 * entry.getValue() / totalDataSize;
            double e = this.entropy(pcda.get(entry.getKey()));
            conditionEntropy += e * p;
        }
        return entropyDatas - conditionEntropy;
    }

    /**
     * Return the value of the target attribute with the largest number of
     * instances in the dataset.
     */
    private String getResult(List<String[]> datas, int targetAttributeIndex) {
        if (datas == null || datas.isEmpty()) {
            return null;
        } else {
            String result = "";
            Map<String, Integer> countMap = new HashMap<>();
            for (String[] arr : datas) {
                String key = arr[targetAttributeIndex];
                if (countMap.containsKey(key)) {
                    countMap.put(key, countMap.get(key) + 1);
                } else {
                    countMap.put(key, 1);
                }
            }

            int maxCount = -1;
            for (Map.Entry<String, Integer> entry : countMap.entrySet()) {
                if (entry.getValue() > maxCount) {
                    maxCount = entry.getValue();
                    result = entry.getKey();
                }
            }
            return result;
        }
    }

    /**
     * Split the dataset into subsets according to the values of the feature
     * attribute.
     */
    private Map<String, List<String[]>> devideDatas(List<String[]> datas, int attributeIndex) {
        Map<String, List<String[]>> subdatas = new HashMap<>();
        if (datas != null && !datas.isEmpty()) {
            for (String[] arr : datas) {
                String key = arr[attributeIndex];
                if (subdatas.containsKey(key)) {
                    subdatas.get(key).add(arr);
                } else {
                    List<String[]> list = new ArrayList<>();
                    list.add(arr);
                    subdatas.put(key, list);
                }
            }
        }
        return subdatas;
    }

    /**
     * Print the decision tree.
     */
    private void printTree(Node tree, Deque<Object> stock) {
        if (tree == null) {
            return;
        }

        if (tree.result != null) {
            this.printBrance(stock, tree.result);
        } else {
            stock.push(this.attributeMap.get(tree.attr));
            for (Map.Entry<String, List<Node>> entry : tree.childs.entrySet()) {
                stock.push(entry.getKey());
                for (Node node : entry.getValue()) {
                    this.printTree(node, stock);
                }
                stock.pop();
            }
            stock.pop();
        }
    }

    /**
     * Print one rule of the decision tree represented by the branch stack.
     */
    private void printBrance(Deque<Object> stock, String target) {
        StringBuffer sb = new StringBuffer();
        int odd = 0;
        for (Object e : stock) {
            sb.insert(0, odd == 0 ? "^" : "=").insert(0, e);
            odd = 1 - odd;
        }
        sb.append("target=").append(target);
        System.out.println(sb.toString());
    }

    /**
     * Build the decision tree.
     */
    private Node buildTree(List<String[]> datas, List<Integer> attributes, double threshold, int targetIndex, Node tree) {
        if (tree == null) {
            tree = new Node(datas, null);
        }
        // split the dataset; the result is empty or non-empty, never null
        Map<String, List<String[]>> subDatas = this.devideDatas(datas, targetIndex);
        if (subDatas.size() <= 1) {
            // there is only one key here
            for (String key : subDatas.keySet()) {
                tree.result = key;
            }
        } else if (attributes == null || attributes.size() < 1) {
            // no features left, so take the majority value directly
            tree.result = this.getResult(datas, targetIndex);
        } else {
            double maxGain = 0;
            int attr = 0;

            for (int attribute : attributes) {
                double gain = this.infoGain(datas, attribute, targetIndex);
                if (gain > maxGain) {
                    maxGain = gain;
                    attr = attribute; // index of the attribute with the largest gain
                }
            }

            if (maxGain < threshold) {
                // the gain threshold has been reached, stop here
                tree.result = this.getResult(datas, targetIndex);
            } else {
                // the stopping condition is not met, keep going
                tree.attr = attr;
                subDatas = this.devideDatas(datas, attr);
                tree.dataset = null;
                attributes.remove(Integer.valueOf(attr));
                for (String key : subDatas.keySet()) {
                    Node childTree = new Node(subDatas.get(key), tree);
                    if (tree.childs.containsKey(key)) {
                        tree.childs.get(key).add(childTree);
                    } else {
                        List<Node> childs = new ArrayList<>();
                        childs.add(childTree);
                        tree.childs.put(key, childs);
                    }
                    this.buildTree(subDatas.get(key), attributes, threshold, targetIndex, childTree);
                }
            }
        }
        return tree;
    }

    /**
     * Classify an instance according to the decision rules.
     */
    private String classify(Node tree, String[] instance) {
        if (tree == null) {
            return null;
        }
        if (tree.result != null) {
            return tree.result;
        }
        if (tree.childs.containsKey(instance[tree.attr])) {
            List<Node> nodes = tree.childs.get(instance[tree.attr]);
            for (Node node : nodes) {
                return this.classify(node, instance);
            }
        }
        return null;
    }

    /**
     * Build the decision tree.
     */
    public void buildTree() {
        this.tree = new Node(this.datas, null);
        this.buildTree(datas, attributes, threshold, targetIndex, tree);
    }

    /**
     * Print the generated rules.
     */
    public void printTree() {
        this.printTree(this.tree, new LinkedList<>());
    }

    /**
     * Classify an instance.
     */
    public String classify(String[] instance) {
        return this.classify(this.tree, instance);
    }

    public static void main(String[] args) {
        List<String[]> dataset = new ArrayList<>();
        dataset.add(new String[] { "1", "youth", "high", "no", "fair", "no" });
        dataset.add(new String[] { "2", "youth", "high", "no", "excellent", "no" });
        dataset.add(new String[] { "3", "middle_aged", "high", "no", "fair", "yes" });
        dataset.add(new String[] { "4", "senior", "medium", "no", "fair", "yes" });
        dataset.add(new String[] { "5", "senior", "low", "yes", "fair", "yes" });
        dataset.add(new String[] { "6", "senior", "low", "yes", "excellent", "no" });
        dataset.add(new String[] { "7", "middle_aged", "low", "yes", "excellent", "yes" });
        dataset.add(new String[] { "8", "youth", "medium", "no", "fair", "no" });
        dataset.add(new String[] { "9", "youth", "low", "yes", "fair", "yes" });
        dataset.add(new String[] { "10", "senior", "medium", "yes", "fair", "yes" });
        dataset.add(new String[] { "11", "youth", "medium", "yes", "excellent", "yes" });
        dataset.add(new String[] { "12", "middle_aged", "medium", "no", "excellent", "yes" });
        dataset.add(new String[] { "13", "middle_aged", "high", "yes", "fair", "yes" });
        dataset.add(new String[] { "14", "senior", "medium", "no", "excellent", "no" });

        List<Integer> attributes = new ArrayList<>();
        attributes.add(4);
        attributes.add(1);
        attributes.add(2);
        attributes.add(3);

        Map<Integer, String> attributeMap = new HashMap<>();
        attributeMap.put(1, "Age");
        attributeMap.put(2, "Income");
        attributeMap.put(3, "Student");
        attributeMap.put(4, "Credit_rating");

        int targetIndex = 5;

        String[] instance = new String[] { "15", "youth", "medium", "yes", "fair" };

        ID3Tree tree = new ID3Tree(dataset, attributes, attributeMap, targetIndex);
        System.out.println("Start build the tree.....");
        tree.buildTree();
        System.out.println("Completed build the tree, start print the tree.....");
        tree.printTree();
        System.out.println("start classify.....");
        String result = tree.classify(instance);
        System.out.println(result);
    }
}
The result of running the Java program is:
Start build the tree.....
Completed build the tree, start print the tree.....
Age=youth^Student=yes^target=yes
Age=youth^Student=no^target=no
Age=middle_aged^target=yes
Age=senior^Credit_rating=excellent^target=no
Age=senior^Credit_rating=fair^target=yes
start classify.....
yes
5. Shortcomings of the ID3 Algorithm
The ID3 algorithm is comparatively slow, and it can only operate on data loaded into memory, so the datasets it can handle are small relative to those of other algorithms.