python實現簡單決策樹(信息增益)——基於周志華的西瓜書數據


數據集如下:

 1 色澤    根蒂    敲聲    紋理    臍部    觸感    好瓜
 2 青綠    蜷縮    濁響    清晰    凹陷    硬滑    是
 3 烏黑    蜷縮    沉悶    清晰    凹陷    硬滑    是
 4 烏黑    蜷縮    濁響    清晰    凹陷    硬滑    是
 5 青綠    蜷縮    沉悶    清晰    凹陷    硬滑    是
 6 淺白    蜷縮    濁響    清晰    凹陷    硬滑    是
 7 青綠    稍蜷    濁響    清晰    稍凹    軟粘    是
 8 烏黑    稍蜷    濁響    稍糊    稍凹    軟粘    是
 9 烏黑    稍蜷    濁響    清晰    稍凹    硬滑    是
10 烏黑    稍蜷    沉悶    稍糊    稍凹    硬滑    否
11 青綠    硬挺    清脆    清晰    平坦    軟粘    否
12 淺白    硬挺    清脆    模糊    平坦    硬滑    否
13 淺白    蜷縮    濁響    模糊    平坦    軟粘    否
14 青綠    稍蜷    濁響    稍糊    凹陷    硬滑    否
15 淺白    稍蜷    沉悶    稍糊    凹陷    硬滑    否
16 烏黑    稍蜷    濁響    清晰    稍凹    軟粘    否
17 淺白    蜷縮    濁響    模糊    平坦    硬滑    否
18 青綠    蜷縮    沉悶    稍糊    稍凹    硬滑    否

基於信息增益的ID3決策樹的原理這里不再贅述,讀者如果不明白可參考西瓜書對這部分內容的講解。

python實現代碼如下:

from collections import Counter
from math import log2

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
  6 # 統計label出現次數
  7 def get_counts(data):
  8     total = len(data)
  9     results = {}
 10     for d in data:
 11         results[d[-1]] = results.get(d[-1], 0) + 1
 12     return results, total
 13 
 14 # 計算信息熵
 15 def calcu_entropy(data):
 16     results, total = get_counts(data)
 17     ent = sum([-1.0*v/total*log2(v/total) for v in results.values()])
 18     return ent
 19 
 20 # 計算每個feature的信息增益
 21 def calcu_each_gain(column, update_data):
 22     total = len(column)
 23     grouped = update_data.iloc[:, -1].groupby(by=column)
 24     temp = sum([len(g[1])/total*calcu_entropy(g[1]) for g in list(grouped)])
 25     return calcu_entropy(update_data.iloc[:, -1]) - temp
 26 
 27 # 獲取最大的信息增益的feature
 28 def get_max_gain(temp_data):
 29     columns_entropy = [(col, calcu_each_gain(temp_data[col], temp_data)) for col in temp_data.iloc[:, :-1]]
 30     columns_entropy = sorted(columns_entropy, key=lambda f: f[1], reverse=True)
 31     return columns_entropy[0]
 32 
 33 # 去掉數據中已存在的列屬性內容
 34 def drop_exist_feature(data, best_feature):
 35     attr = pd.unique(data[best_feature])
 36     new_data = [(nd, data[data[best_feature] == nd]) for nd in attr]
 37     new_data = [(n[0], n[1].drop([best_feature], axis=1)) for n in new_data]
 38     return new_data
 39 
 40 # 獲得出現最多的label
 41 def get_most_label(label_list):
 42     label_dict = {}
 43     for l in label_list:
 44         label_dict[l] = label_dict.get(l, 0) + 1
 45     sorted_label = sorted(label_dict.items(), key=lambda ll: ll[1], reverse=True)
 46     return sorted_label[0][0]
 47 
 48 # 創建決策樹
 49 def create_tree(data_set, column_count):
 50     label_list = data_set.iloc[:, -1]
 51     if len(pd.unique(label_list)) == 1:
 52         return label_list.values[0]
 53     if all([len(pd.unique(data_set[i])) ==1 for i in data_set.iloc[:, :-1].columns]):
 54         return get_most_label(label_list)
 55     best_attr = get_max_gain(data_set)[0]
 56     tree = {best_attr: {}}
 57     exist_attr = pd.unique(data_set[best_attr])
 58     if len(exist_attr) != len(column_count[best_attr]):
 59         no_exist_attr = set(column_count[best_attr]) - set(exist_attr)
 60         for nea in no_exist_attr:
 61             tree[best_attr][nea] = get_most_label(label_list)
 62     for item in drop_exist_feature(data_set, best_attr):
 63         tree[best_attr][item[0]] = create_tree(item[1], column_count)
 64     return tree
 65 
 66 # 決策樹繪制基本參考《機器學習實戰》書內的代碼以及博客:http://blog.csdn.net/c406495762/article/details/76262487
 67 # 獲取樹的葉子節點數目
 68 def get_num_leafs(decision_tree):
 69     num_leafs = 0
 70     first_str = next(iter(decision_tree))
 71     second_dict = decision_tree[first_str]
 72     for k in second_dict.keys():
 73         if isinstance(second_dict[k], dict):
 74             num_leafs += get_num_leafs(second_dict[k])
 75         else:
 76             num_leafs += 1
 77     return num_leafs
 78 
 79 # 獲取樹的深度
 80 def get_tree_depth(decision_tree):
 81     max_depth = 0
 82     first_str = next(iter(decision_tree))
 83     second_dict = decision_tree[first_str]
 84     for k in second_dict.keys():
 85         if isinstance(second_dict[k], dict):
 86             this_depth = 1 + get_tree_depth(second_dict[k])
 87         else:
 88             this_depth = 1
 89         if this_depth > max_depth:
 90             max_depth = this_depth
 91     return max_depth
 92 
 93 # 繪制節點
 94 def plot_node(node_txt, center_pt, parent_pt, node_type):
 95     arrow_args = dict(arrowstyle='<-')
 96     font = FontProperties(fname=r'C:\Windows\Fonts\STXINGKA.TTF', size=15)
 97     create_plot.ax1.annotate(node_txt, xy=parent_pt,  xycoords='axes fraction', xytext=center_pt,
 98                             textcoords='axes fraction', va="center", ha="center", bbox=node_type,
 99                             arrowprops=arrow_args, FontProperties=font)
100 
# Label the edge between parent and child with the splitting value.
def plot_mid_text(cntr_pt, parent_pt, txt_str):
    """Write *txt_str* in red at the midpoint of the parent→child edge."""
    # NOTE(review): hard-coded Windows font path for CJK glyphs — adjust
    # on other platforms.
    font = FontProperties(fname=r'C:\Windows\Fonts\MSYH.TTC', size=10)
    x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
    y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
    # Bug fix: lowercase 'fontproperties' — capitalized 'FontProperties'
    # is rejected by modern matplotlib.
    create_plot.ax1.text(x_mid, y_mid, txt_str, va="center", ha="center",
                         color='red', fontproperties=font)
107 
# Recursively draw the decision tree (layout scheme from "Machine Learning
# in Action").  Mutable layout state lives in function attributes set by
# create_plot(): plot_tree.totalW / totalD are the whole tree's leaf count
# and depth; plot_tree.xoff / yoff track the drawing cursor.
def plot_tree(decision_tree, parent_pt, node_txt):
    """Plot *decision_tree* beneath *parent_pt*, labelling the incoming
    edge with *node_txt* (empty string for the root)."""
    d_node = dict(boxstyle="sawtooth", fc="0.8")
    leaf_node = dict(boxstyle="round4", fc='0.8')
    num_leafs = get_num_leafs(decision_tree)
    first_str = next(iter(decision_tree))
    # Centre this subtree horizontally over the span of leaves it will occupy.
    cntr_pt = (plot_tree.xoff + (1.0 +float(num_leafs))/2.0/plot_tree.totalW, plot_tree.yoff)
    plot_mid_text(cntr_pt, parent_pt, node_txt)
    plot_node(first_str, cntr_pt, parent_pt, d_node)
    second_dict = decision_tree[first_str]
    # Step one level down before drawing children.
    plot_tree.yoff = plot_tree.yoff - 1.0/plot_tree.totalD
    for k in second_dict.keys():
        if isinstance(second_dict[k], dict):
            # Subtree: recurse; the edge is labelled with the attribute value k.
            plot_tree(second_dict[k], cntr_pt, k)
        else:
            # Leaf: advance the horizontal cursor by one leaf slot and draw it.
            plot_tree.xoff = plot_tree.xoff + 1.0/plot_tree.totalW
            plot_node(second_dict[k], (plot_tree.xoff, plot_tree.yoff), cntr_pt, leaf_node)
            plot_mid_text((plot_tree.xoff, plot_tree.yoff), cntr_pt, k)
    # Restore the vertical cursor for this subtree's siblings.
    plot_tree.yoff = plot_tree.yoff + 1.0/plot_tree.totalD
127 
def create_plot(dtree):
    """Render the nested-dict tree *dtree* in a matplotlib window."""
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    create_plot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    # Seed the global layout state consumed by plot_tree.
    plot_tree.totalW = float(get_num_leafs(dtree))
    plot_tree.totalD = float(get_tree_depth(dtree))
    plot_tree.xoff = -0.5 / plot_tree.totalW
    plot_tree.yoff = 1.0
    plot_tree(dtree, (0.5, 1.0), '')
    plt.show()
139 
if __name__ == '__main__':
    # The watermelon 2.0 dataset is GBK-encoded CSV; last column is the label.
    samples = pd.read_csv('./watermelon2.0.csv', encoding='gbk')
    # Full value domain of every attribute, so subtrees can create branches
    # for values missing from their subset.
    attr_domains = {col: list(pd.unique(samples[col]))
                    for col in samples.iloc[:, :-1].columns}
    decision_tree = create_tree(samples, attr_domains)
    create_plot(decision_tree)

繪制的決策樹如下:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM