The ID3 Algorithm


1. A Brief Introduction to the ID3 Algorithm

The ID3 algorithm originated with J. Ross Quinlan at the University of Sydney, who first presented it in 1975; it builds on the earlier Concept Learning System (CLS) algorithm. ID3 is a decision-tree algorithm. A decision tree consists of decision nodes, branches, and leaves. The topmost node is the root; each branch leads either to a new decision node or to a leaf. Each decision node represents a question or decision, usually corresponding to an attribute of the objects being classified, and each leaf represents a possible classification result. Classifying with a decision tree means traversing it from top to bottom: at each node a test is applied, the outcome of the test selects a branch, and the traversal eventually reaches a leaf whose label is the predicted class. In this way, several variables are used jointly to determine the category an instance belongs to.

2. The Foundations of ID3: Information Theory

The ID3 algorithm is based on information theory, using information entropy and information gain as its measures to classify data. Some basic concepts from information theory follow:

Definition 1: If there are n equally probable messages, the probability p of each message is 1/n, and the amount of information conveyed by one message is -log2(1/n).

Definition 2: Given n messages with probability distribution P = (p1, p2, ..., pn), the amount of information conveyed by this distribution is called the entropy of P, written I(P) = -p1*log2(p1) - p2*log2(p2) - ... - pn*log2(pn).

Definition 3: If a set of records T is partitioned by the values of the class attribute into disjoint classes C1, C2, ..., Ck, then the amount of information needed to identify which class an element of T belongs to is Info(T) = I(P), where P is the probability distribution of C1, C2, ..., Ck, i.e. P = (|C1|/|T|, |C2|/|T|, ..., |Ck|/|T|).

Definition 4: If we first partition T by the values of a non-class attribute X into sets T1, T2, ..., Tn, then the information needed to identify the class of an element of T becomes the weighted average of the Info(Ti): Info(X, T) = Σ(i=1..n) (|Ti|/|T|) * Info(Ti).

Definition 5: The information gain is the difference between two information quantities: the information needed to identify an element of T, and the information still needed to identify an element of T after the value of attribute X is known. The formula is: Gain(X, T) = Info(T) - Info(X, T).

The ID3 algorithm computes the information gain of every attribute and selects the attribute with the highest gain as the test attribute for the given set. It creates a node labeled with that attribute, creates one branch for each value of the attribute, and repeats the information-gain computation recursively on each branch, as sketched below.
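To make Definitions 2 through 5 concrete, here is a minimal Python sketch of them; the function names entropy, info, cond_info, and gain are our own, and each record is assumed to be a tuple whose last element is the class label:

# A minimal sketch of Definitions 2-5; the function names are our own.
from collections import Counter
from math import log2

def entropy(probs):                      # Definition 2: I(P)
    return -sum(p * log2(p) for p in probs if p > 0)

def info(records, cls=-1):               # Definition 3: Info(T)
    counts = Counter(r[cls] for r in records)
    return entropy([c / len(records) for c in counts.values()])

def cond_info(records, attr, cls=-1):    # Definition 4: Info(X, T)
    parts = {}
    for r in records:
        parts.setdefault(r[attr], []).append(r)
    return sum(len(t) / len(records) * info(t, cls) for t in parts.values())

def gain(records, attr, cls=-1):         # Definition 5: Gain(X, T)
    return info(records, cls) - cond_info(records, attr, cls)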

3. A Worked Example of the ID3 Steps

Here is a worked example of the ID3 algorithm:

RID age income student credit_rating buy_computer
1 youth high no fair no
2 youth high no excellent no
3 middle_aged high no fair yes
4 senior medium no fair yes
5 senior low yes fair yes
6 senior low yes excellent no
7 middle_aged low yes excellent yes
8 youth medium no fair no
9 youth low yes fair yes
10 senior medium yes fair yes
11 youth medium yes excellent yes
12 middle_aged medium no excellent yes
13 middle_aged high yes fair yes
14 senior medium no excellent no

 

There are 14 records in total. The candidate attributes are age (youth[5], middle_aged[4], senior[5]), income (high[4], medium[6], low[4]), student (no[7], yes[7]), and credit_rating (fair[8], excellent[6]). The target attribute is buy_computer (no[5], yes[9]). The desired result is a model that predicts the value of buy_computer from age, income, student, and credit_rating. Let D denote the initial data set and A the list of candidate attributes. The computation proceeds as follows:

Step 1: compute the information entropy of the target attribute over the data set D: Info(buy_computer) = -(5/14)*log2(5/14) - (9/14)*log2(9/14) = 0.940

Step 2: for each attribute in the candidate list A, compute the entropy of buy_computer given that attribute's value, i.e. the conditional entropy.

For the age attribute: youth(no[3], yes[2]), middle_aged(no[0], yes[4]), senior(no[2], yes[3]). First compute the entropy within each of youth, middle_aged, and senior:

    Info_age(buy_computer | youth) = -(3/5)*log2(3/5) - (2/5)*log2(2/5) = 0.971

    Info_age(buy_computer | middle_aged) = -(4/4)*log2(4/4) - (0/4)*log2(0/4) = 0 (taking 0*log2(0) = 0 by convention)

    Info_age(buy_computer | senior) = -(2/5)*log2(2/5) - (3/5)*log2(3/5) = 0.971

  Then Info_age(buy_computer) = (5/14)*0.971 + (4/14)*0 + (5/14)*0.971 = 0.694

  Similarly, Info_income(buy_computer) = 0.911, Info_student(buy_computer) = 0.789, and Info_credit_rating(buy_computer) = 0.892.

Step 3: compute the information gain for each attribute. The larger the gain, the more of the target attribute's entropy the candidate attribute removes, and the higher in the decision tree it should appear. The results are:

  Gain(age, buy_computer) = Info(buy_computer) - Info_age(buy_computer) = 0.940 - 0.694 = 0.246

  Gain(income, buy_computer) = Info(buy_computer) - Info_income(buy_computer) = 0.940 - 0.911 = 0.029

  Gain(student, buy_computer) = Info(buy_computer) - Info_student(buy_computer) = 0.940 - 0.789 = 0.151

  Gain(credit_rating, buy_computer) = Info(buy_computer) - Info_credit_rating(buy_computer) = 0.940 - 0.892 = 0.048
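These hand computations can be checked with the sketch from Section 2. The tuples below restate the table above with the RID column dropped; note that full-precision arithmetic gives 0.247 and 0.152 for age and student, while the 0.246 and 0.151 above come from subtracting rounded intermediates:

rows = [
    ("youth", "high", "no", "fair", "no"),
    ("youth", "high", "no", "excellent", "no"),
    ("middle_aged", "high", "no", "fair", "yes"),
    ("senior", "medium", "no", "fair", "yes"),
    ("senior", "low", "yes", "fair", "yes"),
    ("senior", "low", "yes", "excellent", "no"),
    ("middle_aged", "low", "yes", "excellent", "yes"),
    ("youth", "medium", "no", "fair", "no"),
    ("youth", "low", "yes", "fair", "yes"),
    ("senior", "medium", "yes", "fair", "yes"),
    ("youth", "medium", "yes", "excellent", "yes"),
    ("middle_aged", "medium", "no", "excellent", "yes"),
    ("middle_aged", "high", "yes", "fair", "yes"),
    ("senior", "medium", "no", "excellent", "no"),
]
print(round(info(rows), 3))                  # 0.94
for i, name in enumerate(["age", "income", "student", "credit_rating"]):
    print(name, round(gain(rows, i), 3))     # age 0.247, income 0.029,
                                             # student 0.152, credit_rating 0.048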

Step 4: select the attribute with the largest information gain, here age, as the current node, and split the initial data set D according to the values of age:

  1. When age is youth, the subset is D1:

RID income student credit_rating buy_computer
1 high no fair no
2 high no excellent no
8 medium no fair no
9 low yes fair yes
11 medium yes excellent yes

  2. When age is middle_aged, the subset is D2:

RID income student credit_rating buy_computer
3 high no fair yes
7 low yes excellent yes
12 medium no excellent yes
13 high yes fair yes

  3. When age is senior, the subset is D3:

RID income student credit_rating buy_computer
4 medium no fair yes
5 low yes fair yes
6 low yes excellent no
10 medium yes fair yes
14 medium no excellent no

 

Step 5: remove the chosen attribute (age) from the candidate list A, then, for each subset Di produced in step 4, repeat the procedure from step 1 with the reduced list A (a compact sketch follows the list below). The iteration stops when:

  1. All records in a branch share a single value of the target attribute, as happens here when age is middle_aged.
  2. Within a branch, some value of the target attribute reaches a user-chosen threshold; for example, with a 90% threshold, iteration stops once one value of buy_computer accounts for at least 90% of the branch.
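The recursion in steps 4 and 5, together with both stopping criteria, fits in a short sketch that reuses gain from the Section 2 sketch; build and purity are our own names, and purity=0.9 would implement the 90% criterion above:

from collections import Counter

def build(records, attrs, cls=-1, purity=1.0):
    counts = Counter(r[cls] for r in records)
    label, freq = counts.most_common(1)[0]
    # stopping criteria: no attributes left, or the majority value of the
    # target attribute reaches the purity threshold (1.0 means a pure branch)
    if not attrs or freq / len(records) >= purity:
        return label
    best = max(attrs, key=lambda a: gain(records, a, cls))   # step 4
    parts = {}
    for r in records:
        parts.setdefault(r[best], []).append(r)
    # step 5: drop the chosen attribute and recurse on each subset;
    # attrs - {best} is a fresh set, so sibling branches are unaffected
    return (best, {v: build(subset, attrs - {best}, cls, purity)
                   for v, subset in parts.items()})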

After several such iterations, the final decision tree is obtained.

The rules it encodes are:

IF AGE = middle_aged THEN buy_computer = yes

IF AGE = youth AND STUDENT = yes THEN buy_computer = yes

IF AGE = youth AND STUDENT = no THEN buy_computer = no

IF AGE = senior AND CREDIT_RATING = excellent THEN buy_computer = no

IF AGE = senior AND CREDIT_RATING = fair THEN buy_computer = yes

So, for the instance ("15", "youth", "medium", "yes", "fair"), the predicted value of buy_computer is "yes".
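Walking the rules for a new instance can also be automated; classify below is our own minimal walker over the nested-tuple tree returned by the build sketch above, using the rows list from Section 3 (RID column dropped):

def classify(tree, instance):
    if not isinstance(tree, tuple):    # a leaf holds the class label
        return tree
    attr, branches = tree
    branch = branches.get(instance[attr])
    return None if branch is None else classify(branch, instance)

tree = build(rows, {0, 1, 2, 3})
print(classify(tree, ("youth", "medium", "yes", "fair")))   # yes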

 

4. Implementations of the ID3 Algorithm

Below are implementations of the ID3 algorithm in Python and in Java.

Python program:

# -*- coding: utf-8 -*-
from math import log


class Node:
    '''Represents a decision tree node.'''

    def __init__(self, parent=None, dataset=None):
        self.dataset = dataset  # training instances that fall on this node
        self.result = None      # resulting class label
        self.attr = None        # ID of the attribute this node splits on
        self.childs = {}        # subtrees, as (value of attr, subtree) pairs
        self.parent = parent    # parent of this node


def entropy(props):
    if not isinstance(props, (tuple, list)):
        return None

    log2 = lambda x: log(x) / log(2)
    e = 0.0
    for p in props:
        if p != 0:
            e = e - p * log2(p)
    return e


def info_gain(D, A, T=-1, return_ratio=False):
    '''Information gain g(D, A) of feature A on training set D.

    g(D, A) = entropy(D) - entropy(D|A)
    The last element of each tuple in D is assumed to be the class label.
    T is the index of the target attribute; -1 means the last element.'''
    if not isinstance(D, (set, list)):
        return None
    if not type(A) is int:
        return None
    C = {}    # counts of each class
    DA = {}   # counts of each value of feature A
    CDA = {}  # counts of each (class, value of A) combination
    for t in D:
        C[t[T]] = C.get(t[T], 0) + 1
        DA[t[A]] = DA.get(t[A], 0) + 1
        CDA[(t[T], t[A])] = CDA.get((t[T], t[A]), 0) + 1

    PC = map(lambda x: 1.0 * x / len(D), C.values())  # class (target attribute) probabilities
    entropy_D = entropy(tuple(PC))  # map() returns an iterator, so convert it to a tuple

    PCDA = {}  # class probabilities conditioned on each value of feature A
    for key, value in CDA.items():
        a = key[1]  # the value of feature A
        pca = 1.0 * value / DA[a]
        PCDA.setdefault(a, []).append(pca)

    condition_entropy = 0.0
    for a, v in DA.items():
        p = 1.0 * v / len(D)
        e = entropy(PCDA[a])
        condition_entropy += e * p

    if return_ratio:
        return (entropy_D - condition_entropy) / entropy_D
    else:
        return entropy_D - condition_entropy


def get_result(D, T=-1):
    '''Return the most frequent value of target attribute T in data set D.'''
    if not isinstance(D, (set, list)):
        return None
    if not type(T) is int:
        return None
    count = {}
    for t in D:
        count[t[T]] = count.get(t[T], 0) + 1
    max_count = 0
    result = None
    for key, value in count.items():
        if value > max_count:
            max_count = value
            result = key
    return result


def devide_set(D, A):
    '''Split data set D into subsets according to the values of feature A.'''
    if not isinstance(D, (set, list)):
        return None
    if not type(A) is int:
        return None
    subset = {}
    for t in D:
        subset.setdefault(t[A], []).append(t)
    return subset


def build_tree(D, A, threshold=0.0001, T=-1, Tree=None, algo="ID3"):
    '''Build a decision tree from data set D and feature set A.

    T is the index of the target attribute in each tuple.
    Both ID3 and C4.5 are currently supported.'''
    if Tree is not None and not isinstance(Tree, Node):
        return None
    if not isinstance(D, (set, list)):
        return None
    if not type(A) is set:
        return None

    if Tree is None:
        Tree = Node(None, D)
    subset = devide_set(D, T)
    if len(subset) <= 1:
        for key in subset.keys():
            Tree.result = key
        del subset
        return Tree
    if len(A) <= 0:
        Tree.result = get_result(D)
        return Tree
    use_gain_ratio = (algo != "ID3")

    max_gain = 0
    attr_id = None
    for a in A:
        gain = info_gain(D, a, return_ratio=use_gain_ratio)
        if gain > max_gain:
            max_gain = gain
            attr_id = a  # keep the feature with the largest information gain
    if max_gain < threshold:
        Tree.result = get_result(D)
        return Tree
    Tree.attr = attr_id
    subD = devide_set(D, attr_id)
    del D[:]  # drop intermediate data to release memory
    Tree.dataset = None
    # Remove the used feature; hand each branch its own copy so that
    # deeper recursion in one branch does not affect its siblings.
    sub_attrs = A - {attr_id}
    for key in subD.keys():
        tree = Node(Tree, subD.get(key))
        Tree.childs[key] = tree
        build_tree(subD.get(key), set(sub_attrs), threshold, T, tree, algo)
    return Tree


def print_brance(brance, target):
    odd = 0
    for e in brance:
        print(e, '=' if odd == 0 else '∧', end=' ')
        odd = 1 - odd
    print("target =", target)


def print_tree(Tree, stack=None):
    if stack is None:
        stack = []
    if Tree is None:
        return
    if Tree.result is not None:
        print_brance(stack, Tree.result)
        return
    stack.append(Tree.attr)
    for key, value in Tree.childs.items():
        stack.append(key)
        print_tree(value, stack)
        stack.pop()
    stack.pop()


def classify(Tree, instance):
    if Tree is None:
        return None
    if Tree.result is not None:
        return Tree.result
    if instance[Tree.attr] in Tree.childs:
        return classify(Tree.childs[instance[Tree.attr]], instance)
    return None


# NOTE: most feature values of this data set were lost in the original
# post and are left blank here.
dataset = [
    ("青年", "", "", "一般", "")
    , ("青年", "", "", "", "")
    , ("青年", "", "", "", "")
    , ("青年", "", "", "一般", "")
    , ("青年", "", "", "一般", "")
    , ("中年", "", "", "一般", "")
    , ("中年", "", "", "", "")
    , ("老年", "", "", "非常好", "")
    , ("老年", "", "", "一般", "")
    , ("老年", "", "", "一般", "")
    , ("老年", "", "", "一般", "")
    , ("老年", "", "", "", "")
    , ("老年", "", "", "一般", "")
    , ("老年", "", "", "一般", "")
    , ("老年", "", "", "一般", "")
]

s = set([0, 1, 3, 4])  # indices of the candidate features
T = build_tree(dataset, s)
print_tree(T)
print(classify(T, ("老年", "", "", "一般", "")))
print(classify(T, ("老年", "", "", "一般", "")))
print(classify(T, ("老年", "", "", "", "")))
print(classify(T, ("青年", "", "", "", "")))
print(classify(T, ("中年", "", "", "", "")))

 

The training set of this Python program is not the example given above; its output is:

0 = 青年 ∧ 1 = 否 ∧ target = 否
0 = 青年 ∧ 1 = 是 ∧ target = 是
0 = 中年 ∧ target = 否
0 = 老年 ∧ 3 = 好 ∧ target = 是
0 = 老年 ∧ 3 = 非常好 ∧ target = 是
0 = 老年 ∧ 3 = 一般 ∧ 4 = 否 ∧ target = 否
0 = 老年 ∧ 3 = 一般 ∧ 4 = 是 ∧ target = 是
否
是
是
是
否
[Finished in 0.3s]

  

Java program. This program uses the data set from the example above; the code and results follow:

import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

public class ID3Tree {
    private List<String[]> datas;
    private List<Integer> attributes;
    private double threshold = 0.0001;
    private int targetIndex = 1;
    private Node tree;
    private Map<Integer, String> attributeMap;

    protected ID3Tree() {
        super();
    }

    public ID3Tree(List<String[]> datas, List<Integer> attributes, Map<Integer, String> attributeMap, int targetIndex) {
        this(datas, attributes, attributeMap, 0.0001, targetIndex, null);
    }

    public ID3Tree(List<String[]> datas, List<Integer> attributes, Map<Integer, String> attributeMap, double threshold, int targetIndex, Node tree) {
        super();
        this.datas = datas;
        this.attributes = attributes;
        this.attributeMap = attributeMap;
        this.threshold = threshold;
        this.targetIndex = targetIndex;
        this.tree = tree;
    }

    /**
     * A tree node.
     *
     * @author Gerry.Liu
     */
    class Node {
        private List<String[]> dataset; // training instances that fall on this node
        private String result; // resulting class label
        private int attr; // index of the attribute this node splits on
        private Node parent; // parent of this node
        private Map<String, List<Node>> childs; // children of this node

        public Node(List<String[]> datas, Node parent) {
            this.dataset = datas;
            this.parent = parent;
            this.childs = new HashMap<>();
        }
    }

    class KeyValue {
        private String first;
        private String second;

        public KeyValue(String first, String second) {
            super();
            this.first = first;
            this.second = second;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + getOuterType().hashCode();
            result = prime * result + ((first == null) ? 0 : first.hashCode());
            result = prime * result + ((second == null) ? 0 : second.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (obj == null)
                return false;
            if (getClass() != obj.getClass())
                return false;
            KeyValue other = (KeyValue) obj;
            if (!getOuterType().equals(other.getOuterType()))
                return false;
            if (first == null) {
                if (other.first != null)
                    return false;
            } else if (!first.equals(other.first))
                return false;
            if (second == null) {
                if (other.second != null)
                    return false;
            } else if (!second.equals(other.second))
                return false;
            return true;
        }

        private ID3Tree getOuterType() {
            return ID3Tree.this;
        }
    }

    /**
     * Compute the entropy from a list of probabilities:<br/>
     * entropy(p1,p2....pn) = -p1*log2(p1) - p2*log2(p2) - ..... - pn*log2(pn)
     *
     * @param props
     * @return
     */
    private double entropy(List<Double> props) {
        if (props == null || props.isEmpty()) {
            return 0;
        } else {
            double result = 0;
            for (double p : props) {
                if (p > 0) {
                    result = result - p * Math.log(p) / Math.log(2);
                }
            }
            return result;
        }
    }

    /**
     * Compute the probability of each count.
     *
     * @param totalRecords
     * @param counts
     * @return
     */
    private List<Double> calcProbability(int totalRecords, Collection<Integer> counts) {
        if (totalRecords == 0 || counts == null || counts.isEmpty()) {
            return null;
        }

        List<Double> result = new ArrayList<>();
        for (int count : counts) {
            result.add(1.0 * count / totalRecords);
        }
        return result;
    }

    /**
     * Information gain Gain(datas, attribute):<br/>
     * the gain of feature attribute (A) on training set datas (D),<br/>
     * g(D,A) = entropy(D) - entropy(D|A)<br/>
     *
     * @param datas
     *            the training set
     * @param attributeIndex
     *            index of the feature attribute
     * @param targetAttributeIndex
     *            index of the target attribute
     * @return
     */
    private double infoGain(List<String[]> datas, int attributeIndex, int targetAttributeIndex) {
        if (datas == null || datas.isEmpty()) {
            return 0;
        }

        Map<String, Integer> targetAttributeCountMap = new HashMap<String, Integer>(); // class (target attribute) counts
        Map<String, Integer> featureAttributesCountMap = new HashMap<>(); // counts of the feature attribute's values
        Map<KeyValue, Integer> tfAttributeCountMap = new HashMap<>(); // counts of each (class, feature value) combination

        for (String[] arrs : datas) {
            String tv = arrs[targetAttributeIndex];
            String fv = arrs[attributeIndex];
            if (targetAttributeCountMap.containsKey(tv)) {
                targetAttributeCountMap.put(tv, targetAttributeCountMap.get(tv) + 1);
            } else {
                targetAttributeCountMap.put(tv, 1);
            }
            if (featureAttributesCountMap.containsKey(fv)) {
                featureAttributesCountMap.put(fv, featureAttributesCountMap.get(fv) + 1);
            } else {
                featureAttributesCountMap.put(fv, 1);
            }
            KeyValue key = new KeyValue(fv, tv);
            if (tfAttributeCountMap.containsKey(key)) {
                tfAttributeCountMap.put(key, tfAttributeCountMap.get(key) + 1);
            } else {
                tfAttributeCountMap.put(key, 1);
            }
        }

        int totalDataSize = datas.size();
        // probabilities of the target values
        List<Double> probabilitys = calcProbability(totalDataSize, targetAttributeCountMap.values());
        // entropy of the target attribute
        double entropyDatas = this.entropy(probabilitys);

        // conditional probabilities:
        // first, the distribution of the target values given each feature value
        Map<String, List<Double>> pcda = new HashMap<>();
        for (Map.Entry<KeyValue, Integer> entry : tfAttributeCountMap.entrySet()) {
            String key = entry.getKey().first;
            double pca = 1.0 * entry.getValue() / featureAttributesCountMap.get(key);
            if (pcda.containsKey(key)) {
                pcda.get(key).add(pca);
            } else {
                List<Double> list = new ArrayList<Double>();
                list.add(pca);
                pcda.put(key, list);
            }
        }
        // second, take the entropy for each feature value and form the weighted average
        double conditionEntropy = 0.0;
        for (Map.Entry<String, Integer> entry : featureAttributesCountMap.entrySet()) {
            double p = 1.0 * entry.getValue() / totalDataSize;
            double e = this.entropy(pcda.get(entry.getKey()));
            conditionEntropy += e * p;
        }
        return entropyDatas - conditionEntropy;
    }

    /**
     * Return the most frequent value of the target attribute in the data set.
     *
     * @param datas
     * @param targetAttributeIndex
     * @return
     */
    private String getResult(List<String[]> datas, int targetAttributeIndex) {
        if (datas == null || datas.isEmpty()) {
            return null;
        } else {
            String result = "";
            Map<String, Integer> countMap = new HashMap<>();
            for (String[] arr : datas) {
                String key = arr[targetAttributeIndex];
                if (countMap.containsKey(key)) {
                    countMap.put(key, countMap.get(key) + 1);
                } else {
                    countMap.put(key, 1);
                }
            }

            int maxCount = -1;
            for (Map.Entry<String, Integer> entry : countMap.entrySet()) {
                if (entry.getValue() > maxCount) {
                    maxCount = entry.getValue();
                    result = entry.getKey();
                }
            }
            return result;
        }
    }

    /**
     * Split the data set into subsets by the values of the feature attribute.
     *
     * @param datas
     *            the data set
     * @param attributeIndex
     *            index of the feature attribute
     * @return
     */
    private Map<String, List<String[]>> devideDatas(List<String[]> datas, int attributeIndex) {
        Map<String, List<String[]>> subdatas = new HashMap<>();
        if (datas != null && !datas.isEmpty()) {
            for (String[] arr : datas) {
                String key = arr[attributeIndex];
                if (subdatas.containsKey(key)) {
                    subdatas.get(key).add(arr);
                } else {
                    List<String[]> list = new ArrayList<>();
                    list.add(arr);
                    subdatas.put(key, list);
                }
            }
        }
        return subdatas;
    }

    /**
     * Print the decision tree.
     *
     * @param tree
     * @param stock
     */
    private void printTree(Node tree, Deque<Object> stock) {
        if (tree == null) {
            return;
        }

        if (tree.result != null) {
            this.printBrance(stock, tree.result);
        } else {
            stock.push(this.attributeMap.get(tree.attr));
            for (Map.Entry<String, List<Node>> entry : tree.childs.entrySet()) {
                stock.push(entry.getKey());
                for (Node node : entry.getValue()) {
                    this.printTree(node, stock);
                }
                stock.pop();
            }
            stock.pop();
        }
    }

    /**
     * Print the rule encoded by one branch of the decision tree.
     *
     * @param stock
     * @param target
     */
    private void printBrance(Deque<Object> stock, String target) {
        StringBuffer sb = new StringBuffer();
        int odd = 0;
        for (Object e : stock) {
            sb.insert(0, odd == 0 ? "^" : "=").insert(0, e);
            odd = 1 - odd;
        }
        sb.append("target=").append(target);
        System.out.println(sb.toString());
    }

    /**
     * Build a decision tree.
     *
     * @param datas
     * @param attributes
     * @param threshold
     * @param targetIndex
     * @param tree
     * @return
     */
    private Node buildTree(List<String[]> datas, List<Integer> attributes, double threshold, int targetIndex, Node tree) {
        if (tree == null) {
            tree = new Node(datas, null);
        }
        // split on the target attribute; the result may be empty but is never null
        Map<String, List<String[]>> subDatas = this.devideDatas(datas, targetIndex);
        if (subDatas.size() <= 1) {
            // there is only one key here
            for (String key : subDatas.keySet()) {
                tree.result = key;
            }
        } else if (attributes == null || attributes.size() < 1) {
            // no features left, so take the majority value
            tree.result = this.getResult(datas, targetIndex);
        } else {
            double maxGain = 0;
            int attr = 0;

            for (int attribute : attributes) {
                double gain = this.infoGain(datas, attribute, targetIndex);
                if (gain > maxGain) {
                    maxGain = gain;
                    attr = attribute; // index of the feature with the largest gain
                }
            }

            if (maxGain < threshold) {
                // the gain threshold has been reached
                tree.result = this.getResult(datas, targetIndex);
            } else {
                // the stopping conditions are not met, so keep splitting
                tree.attr = attr;
                subDatas = this.devideDatas(datas, attr);
                tree.dataset = null;
                // remove the used attribute; each subtree gets its own copy so that
                // deeper recursion in one branch does not affect its siblings
                List<Integer> remaining = new ArrayList<>(attributes);
                remaining.remove(Integer.valueOf(attr));
                for (String key : subDatas.keySet()) {
                    Node childTree = new Node(subDatas.get(key), tree);
                    if (tree.childs.containsKey(key)) {
                        tree.childs.get(key).add(childTree);
                    } else {
                        List<Node> childs = new ArrayList<>();
                        childs.add(childTree);
                        tree.childs.put(key, childs);
                    }
                    this.buildTree(subDatas.get(key), new ArrayList<>(remaining), threshold, targetIndex, childTree);
                }
            }
        }
        return tree;
    }

    /**
     * Classify an instance according to the decision rules.
     *
     * @param instance
     * @return
     */
    private String classify(Node tree, String[] instance) {
        if (tree == null) {
            return null;
        }
        if (tree.result != null) {
            return tree.result;
        }
        if (tree.childs.containsKey(instance[tree.attr])) {
            List<Node> nodes = tree.childs.get(instance[tree.attr]);
            for (Node node : nodes) {
                return this.classify(node, instance);
            }
        }
        return null;
    }

    /**
     * Build the decision tree.
     */
    public void buildTree() {
        this.tree = new Node(this.datas, null);
        this.buildTree(datas, attributes, threshold, targetIndex, tree);
    }

    /**
     * Print the generated rules.
     */
    public void printTree() {
        this.printTree(this.tree, new LinkedList<>());
    }

    /**
     * Get the predicted value for an instance.
     *
     * @param instance
     * @return
     */
    public String classify(String[] instance) {
        return this.classify(this.tree, instance);
    }

    public static void main(String[] args) {
        List<String[]> dataset = new ArrayList<>();
        dataset.add(new String[] { "1", "youth", "high", "no", "fair", "no" });
        dataset.add(new String[] { "2", "youth", "high", "no", "excellent", "no" });
        dataset.add(new String[] { "3", "middle_aged", "high", "no", "fair", "yes" });
        dataset.add(new String[] { "4", "senior", "medium", "no", "fair", "yes" });
        dataset.add(new String[] { "5", "senior", "low", "yes", "fair", "yes" });
        dataset.add(new String[] { "6", "senior", "low", "yes", "excellent", "no" });
        dataset.add(new String[] { "7", "middle_aged", "low", "yes", "excellent", "yes" });
        dataset.add(new String[] { "8", "youth", "medium", "no", "fair", "no" });
        dataset.add(new String[] { "9", "youth", "low", "yes", "fair", "yes" });
        dataset.add(new String[] { "10", "senior", "medium", "yes", "fair", "yes" });
        dataset.add(new String[] { "11", "youth", "medium", "yes", "excellent", "yes" });
        dataset.add(new String[] { "12", "middle_aged", "medium", "no", "excellent", "yes" });
        dataset.add(new String[] { "13", "middle_aged", "high", "yes", "fair", "yes" });
        dataset.add(new String[] { "14", "senior", "medium", "no", "excellent", "no" });

        List<Integer> attributes = new ArrayList<>();
        attributes.add(4);
        attributes.add(1);
        attributes.add(2);
        attributes.add(3);

        Map<Integer, String> attributeMap = new HashMap<>();
        attributeMap.put(1, "Age");
        attributeMap.put(2, "Income");
        attributeMap.put(3, "Student");
        attributeMap.put(4, "Credit_rating");

        int targetIndex = 5;

        String[] instance = new String[] { "15", "youth", "medium", "yes", "fair" };

        ID3Tree tree = new ID3Tree(dataset, attributes, attributeMap, targetIndex);
        System.out.println("start build the tree");
        tree.buildTree();
        System.out.println("completed build the tree, start print the tree");
        tree.printTree();
        System.out.println("start classify.....");
        String result = tree.classify(instance);
        System.out.println(result);
    }
}

Running the Java program produces:

Start build the tree.....
Completed build the tree, start print the tree.....
Age=youth^Student=yes^target=yes
Age=youth^Student=no^target=no
Age=middle_aged^target=yes
Age=senior^Credit_rating=excellent^target=no
Age=senior^Credit_rating=fair^target=yes
start classify.....
yes

5. Shortcomings of ID3

ID3 runs relatively slowly and can only operate on data loaded in memory, so the data sets it can handle are small compared with those of other algorithms. Its information-gain criterion also tends to favor attributes with many distinct values; C4.5 corrects this with the gain ratio, which the Python implementation above already supports through its return_ratio flag.

 

