編程實現基於信息熵進行划分選擇的決策樹算法(ID3,C4.5)


1.題目理解

 

編程實現基於信息熵進行划分選擇的決策樹算法(包括ID3,C4.5兩種算法),並為表4.3中的數據生成一棵決策樹。

 

2.算法原理

  2.1信息熵

  度量樣本集合純度最常用的一種指標, 信息熵的值越小,則樣本集合D的純度越高。

  

 

  2.2信息增益(ID3中使用)

  假定離散屬性α有V個可能的取值{ a1 ,…,av},若使用α來對樣本集D進行划分,則會產生V個分支結點,其中第v個分支結點包含了D中所有在屬性α上取值為 av的樣本,記為Dv 。計算出Dv 的信息熵,再考慮到不同的分支結點所包含的樣本數不同,給分支結點賦予權重|Dv|/|D| .即樣本數越多的分支結點的影響越大,於是可計算出用屬性α對樣本集D進行划分所獲得的“信息增益”:

  

  一般而言,信息增益越大,意味着使用屬性a進行划分所獲得的“純度提升”越大,因此可以用信息增益來作為決策樹的划分屬性的准則。

  不足:實際上,信息增益准則對可取值數目較多的屬性有所偏好:當某一屬性a可取值的數目較多時,每個屬性值下的樣本集合Dv數目較小,相對其他屬性而言,樣本集合Dv的純度更高,從而導致該屬性的信息增益偏大,影響決策樹的泛化能力。

 

  2.3信息增益率(C4.5中使用)

 

  為了減少信息增益准則的不利影響,使用增益率來選擇最優屬性划分,增益率定義為:

 

 

 

其中,

 

  不足:通常屬性a的可能取值數目越多,固有值通常會越大;當屬性可取值數目較少時,固有值較小,導致增益率可能偏大,即增益率准則對可取值數目較少的屬性有所偏好。

 

  優化:綜合信息增益和增益率的特點,C4.5算法並不是直接選擇增益率最大的候選划分屬性,而是使用了一個啟發式:先從候選划分屬性中找出信息增益高於平均水平的屬性,再從中選擇增益率最高的。

  2.4決策樹的生成算法

  假設訓練集是D,屬性集是A,遞歸生成決策樹。

  首先生成結點node,如果D中的樣本全部屬於同一類別C,將node標記為C類葉結點,返回上一層遞歸;

  如果屬性集A是空集或者D中的樣本有完全相同的屬性值,將node標記為葉結點,類別是此時D中樣本數最多的類,返回上一層遞歸;

  否則,根據信息增益或增益率選出最優划分屬性a,對a中的每一個屬性值生成一個分支,選擇D中對應屬性值的樣本作為子集Dv :如果某屬性值對應的子集為空集,將該分支對應的子結點標記為葉結點,類別是D中樣本數最多的類;如果某屬性值對應的子集不是空集,將Dv 和A-a作為輸入遞歸生成決策樹。

 

3.算法設計和關鍵代碼

  3.1計算信息熵

  在西瓜數據集中,統計好瓜和和壞瓜的數目;計算好瓜、壞瓜分別占西瓜總數的比例;根據公式計算出信息熵;

 1 # 信息熵
 2 def entropy(melons):
 3     m_num = len(melons)     # 瓜數
 4     good_num = 0
 5     bad_num = 0
 6     for i in range(m_num):
 7         if melons[i][7]==1: good_num +=1
 8     bad_num = m_num - good_num
 9     p_good = good_num/m_num
10     p_bad = bad_num/m_num
11     ent = -(p_good * math.log(p_good, 2) + p_bad * math.log(p_bad, 2))
12     return ent

 

  3.2計算不同屬性的信息增益(屬性分為連續值和離散值計算)並選擇最佳屬性(ID3樹)

 

  離散值有明顯的類別可以直接計算,連續值使用二分法進行分類,將每種不同的分法都看作一類,最終與離散值一起選擇使信息增益最高的屬性;

 

  求解信息熵時和信息增益時,要注意每類西瓜數不能為零;如果Dv類西瓜數量為0,則對應的信息熵為0;信息增益同理。(防止除數為0和log0的情況)。

 

  1 # 計算信息增益
  2 def Gain(melons, chara):
  3     feature_ent = 0
  4     gain = 0
  5     m_num = len(melons)
  6 
  7     # 連續density
  8     if chara >= 6:
  9         d1_good = 0  # 小於等於div
 10         d1_bad = 0
 11         d2_good = 0
 12         d2_bad = 0
 13 
 14         # for div in divide_point:
 15         for j in range(m_num):
 16             if melons[j][6] <= divide_point[chara - 6] and melons[j][7] == 1: d1_good += 1
 17             if melons[j][6] <= divide_point[chara - 6] and melons[j][7] == 0: d1_bad += 1
 18             if melons[j][6] > divide_point[chara - 6] and melons[j][7] == 1: d2_good += 1
 19             if melons[j][6] > divide_point[chara - 6] and melons[j][7] == 0: d2_bad += 1
 20             d1 = d1_good + d1_bad
 21             d2 = d2_good + d2_bad
 22             # 防止除以0
 23             if d1_good==0 and d1_bad==0:
 24                 p1g = 0
 25                 p1b = 0
 26             else:
 27                 p1g = d1_good/d1
 28                 p1b = d1_bad/d1
 29             if d2_good==0 and d2_bad==0:
 30                 p2g = 0
 31                 p2b = 0
 32             else:
 33                 p2g = d2_good/d2
 34                 p2b = d2_bad/d2
 35             # 防止log0
 36             if d1_good != 0 and d1_bad != 0:
 37                 entd1 = -d1 / m_num * (-(p1g * math.log(p1g, 2) + p1b * math.log(p1b, 2)))
 38             elif d1_good==0 and d1_bad!=0:
 39                 entd1 = -d1 / m_num *(-p1b * math.log(p1b, 2))
 40             elif d1_good!=0 and d1_bad==0:
 41                 entd1 = -d1 / m_num * (-p1g * math.log(p1g, 2))
 42             else:
 43                 entd1 = 0
 44 
 45             if d2_good != 0 and d2_bad != 0:
 46                 entd2 = -d2 / m_num * (-(p2g * math.log(p2g, 2) + p2b * math.log(p2b, 2)))
 47             elif d2_good==0 and d2_bad!=0:
 48                 entd2 = -d2 / m_num *(-p2b * math.log(p2b, 2))
 49             elif d2_good!=0 and d2_bad==0:
 50                 entd2 = -d2 / m_num * (-p2g * math.log(p2g, 2))
 51             else:
 52                 entd2 = 0
 53             gain = entropy(melons) + entd1 + entd2
 54 
 55     # 觸感
 56     elif chara==5:
 57         d1_good = 0
 58         d1_bad = 0
 59         d2_good =0
 60         d2_bad = 0
 61         for i in range(m_num):
 62             if melons[i][5] == 0 and melons[i][7] == 1: d1_good += 1
 63             if melons[i][5] == 0 and melons[i][7] == 0: d1_bad += 1
 64             if melons[i][5] == 1 and melons[i][7] == 1: d2_good += 1
 65             if melons[i][5] == 1 and melons[i][7] == 0: d2_bad += 1
 66         d1 = d1_good + d1_bad
 67         d2 = d2_good + d2_bad
 68 
 69         if d1 == 0:
 70             entd1 = 0
 71         elif d1_good == 0:
 72             p1b = d1_bad / d1
 73             entd1 = -(p1b * math.log(p1b, 2))
 74         elif d1_bad == 0:
 75             p1g = d1_good / d1
 76             entd1 = -(p1g * math.log(p1g, 2))
 77         elif d1_good != 0 and d1_bad != 0:
 78             p1g = d1_good / d1
 79             p1b = d1_bad / d1
 80             entd1 = -(p1g * math.log(p1g, 2) + p1b * math.log(p1b, 2))
 81 
 82         if d2 == 0:
 83             entd2 = 0
 84         elif d2_good == 0:
 85             p2b = d2_bad / d2
 86             entd2 = -(p2b * math.log(p2b, 2))
 87         elif d2_bad == 0:
 88             p2g = d2_good / d2
 89             entd2 = -(p2g * math.log(p2g, 2))
 90         elif d2_good != 0 and d2_bad != 0:
 91             p2g = d2_good / d2
 92             p2b = d2_bad / d2
 93             entd2 = -(p2g * math.log(p2g, 2) + p2b * math.log(p2b, 2))
 94         feature_ent = feature_ent-(entd1*d1/m_num+entd2*d2/m_num)
 95         gain = entropy(melons) + feature_ent
 96 
 97     # 其余離散特征
 98     else: # chara==0 or chara==1 or chara==2 or chara==3 or chara==4:
 99         attr_mat = [['青綠', '烏黑', '淺白'], ['蜷縮', '稍蜷', '硬挺'], ['濁響', '沉悶', '清脆'], ['清晰', '稍糊', '模糊'], ['凹陷', '稍凹', '平坦']]
100         d1_good = 0
101         d1_bad = 0
102         d2_good = 0
103         d2_bad = 0
104         d3_good = 0
105         d3_bad = 0
106         for i in range(m_num):
107             if melons[i][chara] == 0 and melons[i][7] == 1: d1_good += 1
108             if melons[i][chara] == 0 and melons[i][7] == 0: d1_bad += 1
109             if melons[i][chara] == 1 and melons[i][7] == 1: d2_good += 1
110             if melons[i][chara] == 1 and melons[i][7] == 0: d2_bad += 1
111             if melons[i][chara] == 2 and melons[i][7] == 1: d3_good += 1
112             if melons[i][chara] == 2 and melons[i][7] == 0: d3_bad += 1
113         d1 = d1_good + d1_bad
114         d2 = d2_good + d2_bad
115         d3 = d3_good + d3_bad
116         if d1 == 0:
117             entd1 = 0
118         elif d1_good == 0:
119             p1b = d1_bad / d1
120             entd1 = -(p1b * math.log(p1b, 2))
121         elif d1_bad == 0:
122             p1g = d1_good / d1
123             entd1 = -(p1g * math.log(p1g, 2))
124         elif d1_good != 0 and d1_bad != 0:
125             p1g = d1_good / d1
126             p1b = d1_bad / d1
127             entd1 = -(p1g * math.log(p1g, 2) + p1b * math.log(p1b, 2))
128 
129         if d2 == 0:
130             entd2 = 0
131         elif d2_good == 0:
132             p2b = d2_bad / d2
133             entd2 = -(p2b * math.log(p2b, 2))
134         elif d2_bad == 0:
135             p2g = d2_good / d2
136             entd2 = -(p2g * math.log(p2g, 2))
137         elif d2_good != 0 and d2_bad != 0:
138             p2g = d2_good / d2
139             p2b = d2_bad / d2
140             entd2 = -(p2g * math.log(p2g, 2) + p2b * math.log(p2b, 2))
141 
142         if d3 == 0:
143             entd3 = 0
144         elif d3_good == 0:
145             p3b = d3_bad / d3
146             entd3 = -(p3b * math.log(p3b, 2))
147         elif d3_bad == 0:
148             p3g = d3_good / d3
149             entd3 = -(p3g * math.log(p3g, 2))
150         elif d3_good != 0 and d3_bad != 0:
151             p3g = d3_good / d3
152             p3b = d3_bad / d3
153             entd3 = -(p3g * math.log(p3g, 2) + p3b * math.log(p3b, 2))
154 
155         feature_ent = feature_ent-(entd1 * d1 / m_num + entd2 * d2 / m_num + entd3 * d3 / m_num)
156         gain = entropy(melons) + feature_ent
157 
158     return [gain, chara]
1 def choose_best_feature(melons, A):
2     max_ent= Gain(melons, A[0])
3     for i in range(len(A)):
4         ent_temp = Gain(melons, A[i])
5         if ent_temp[0]>max_ent[0]:
6             max_ent = ent_temp
7     return max_ent

 

  3.3計算不同屬性的信息增益率並選擇最佳屬性

 1 # 計算增益率
 2 def Gainratio(melons, chara):
 3     # 離散值
 4     if chara<5:
 5         in_value = 0
 6         num0 = len(melons)
 7         for i in range(3):
 8             num = 0
 9             for dd in melons:
10                 if dd[chara]==i:
11                     num += 1
12             if num!=0:
13                 in_value -= abs(num/num0)*math.log(abs(num/num0), 2)
14         gain = Gain(melons, chara)
15         g_ratio = gain[0]/in_value
16 
17     elif chara==5:
18         in_value = 0
19         num0 = len(melons)
20         for i in range(2):
21             num = 0
22             for dd in melons:
23                 if dd[chara] == i:
24                     num += 1
25             if num != 0:
26                 in_value -= abs(num / num0) * math.log(abs(num / num0), 2)
27         gain = Gain(melons, chara)
28         g_ratio = gain[0] / in_value
29 
30     else:
31         # 連續值
32         in_value = 0
33         num0 = len(melons)
34         s = 0
35         l = 0
36         for j in melons:
37             if j[6]>divide_point[chara-6]:
38                 l += 1
39             else:
40                 s += 1
41         if l!=0 and s!=0:
42             in_value -= abs(l / num0) * math.log(abs(l / num0), 2)
43             in_value -= abs(s / num0) * math.log(abs(s / num0), 2)
44         elif s==0 and l!=0:
45             in_value -= abs(l/num0)*math.log(abs(l/num0), 2)
46         elif l==0 and s!=0:
47             in_value -= abs(s/num0)*math.log(abs(s/num0), 2)
48         # if in_value==0:
49         #     g_ratio = 0
50         # else:
51         gain = Gain(melons, chara)
52         g_ratio = gain[0] / in_value
53 
54     return [g_ratio, chara]

 

1 def choose_best_feature(melons, A):
2     new_ent, new_A = choose_some_feature(melons, A)
3     max_ent= Gainratio(melons, new_A[0])
4     for i in range(len(new_A)):
5         ent_temp = Gainratio(melons, new_A[i])
6         if ent_temp[0]>max_ent[0]:
7             max_ent = ent_temp
8     return max_ent

5.結果展示

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM