決策樹--ID3 算法(一)


Contents

 

     1. 決策樹的基本認識

     2. ID3算法介紹

     3. 信息熵與信息增益

     4. ID3算法的C++實現

 

 

1. 決策樹的基本認識

 

   決策樹是一種依托決策而建立起來的一種樹。在機器學習中,決策樹是一種預測模型,代表的是一種對

   象屬性與對象值之間的一種映射關系,每一個節點代表某個對象,樹中的每一個分叉路徑代表某個可能

   的屬性值,而每一個葉子節點則對應從根節點到該葉子節點所經歷的路徑所表示的對象的值。決策樹僅

   有單一輸出,如果有多個輸出,可以分別建立獨立的決策樹以處理不同的輸出。接下來講解ID3算法。

 

 

2. ID3算法介紹

 

   ID3算法是決策樹的一種,它是基於奧卡姆剃刀原理的,即用盡量用較少的東西做更多的事。ID3算法

   即Iterative Dichotomiser 3迭代二叉樹3代,是Ross Quinlan發明的一種決策樹算法,這個

   算法的基礎就是上面提到的奧卡姆剃刀原理,越是小型的決策樹越優於大的決策樹,盡管如此,也不總

   是生成最小的樹型結構,而是一個啟發式算法。

 

   在信息論中,期望信息越小,那么信息增益就越大,從而純度就越高。ID3算法的核心思想就是以信息

   增益來度量屬性的選擇,選擇分裂后信息增益最大的屬性進行分裂。該算法采用自頂向下的貪婪搜索遍

   歷可能的決策空間。

 

 

3. 信息熵與信息增益

 

   在信息增益中,重要性的衡量標准就是看特征能夠為分類系統帶來多少信息,帶來的信息越多,該特征越

   重要。在認識信息增益之前,先來看看信息熵的定義

 

   這個概念最早起源於物理學,在物理學中是用來度量一個熱力學系統的無序程度,而在信息學里面,熵

   是對不確定性的度量。在1948年,香農引入了信息熵,將其定義為離散隨機事件出現的概率,一個系統越

   是有序,信息熵就越低,反之一個系統越是混亂,它的信息熵就越高。所以信息熵可以被認為是系統有序

   化程度的一個度量。

 

   假如一個隨機變量的取值為,每一種取到的概率分別是,那么

   的熵定義為

 

             

 

   意思是一個變量的變化情況可能越多,那么它攜帶的信息量就越大。

 

   對於分類系統來說,類別是變量,它的取值是,而每一個類別出現的概率分別是

 

             

 

   而這里的就是類別的總數,此時分類系統的熵就可以表示為

 

             

 

   以上就是信息熵的定義,接下來介紹信息增益

 

   信息增益是針對一個一個特征而言的,就是看一個特征,系統有它和沒有它時的信息量各是多少,兩者

   的差值就是這個特征給系統帶來的信息量,即信息增益

 

   接下來以天氣預報的例子來說明。下面是描述天氣數據表,學習目標是play或者not play

 

   

 

   可以看出,一共14個樣例,包括9個正例和5個負例。那么當前信息的熵計算如下

 

   

 

   在決策樹分類問題中,信息增益就是決策樹在進行屬性選擇划分前和划分后信息的差值。假設利用

   屬性Outlook來分類,那么如下圖

 

   

 

      划分后,數據被分為三部分了,那么各個分支的信息熵計算如下

 

       

 

       那么划分后的信息熵為

 

        

 

        代表在特征屬性的條件下樣本的條件熵。那么最終得到特征屬性帶來的信息增益為

 

        

 

   信息增益的計算公式如下

 

   

 

   其中為全部樣本集合,是屬性所有取值的集合,的其中一個屬性值,中屬性

   值為的樣例集合,中所含樣例數。

 

   在決策樹的每一個非葉子結點划分之前,先計算每一個屬性所帶來的信息增益,選擇最大信息增益的屬性來划

   分,因為信息增益越大,區分樣本的能力就越強,越具有代表性,很顯然這是一種自頂向下的貪心策略。以上

   就是ID3算法的核心思想。

3.決策樹停止的條件

    如果發生以下的情況,決策樹將停止分割

    1.改群數據的每一筆數據已經歸類到每一類數據中,即數據已經不能繼續在分。

    2.該群數據已經找不到新的屬性進行節點分割

    3.該群數據沒有任何未處理的數據

 

 

 

4. ID3算法的C++實現

 

   接下來開始用C++實現ID3算法,包括以下文件

 

   

 

ID3.h

[cpp]  view plain  copy
 
  在CODE上查看代碼片 派生到我的代碼片
  1. #ifndef _ID3_H_  
  2. #define _ID3_H_  
  3.    
  4. #include <utility>  
  5. #include <list>  
  6. #include <map>  
  7.    
  8. #define Type int   //樣本數據類型  
  9.    
  10. #define   Map1        std::map< int, Type >    //定義一維map  
  11. #define   Map2        std::map< int, Map1 >    //定義二維map  
  12. #define   Map3        std::map< int, Map2 >    //定義三維map  
  13. #define   Pair        std::pair<int, Type>  
  14. #define   List        std::list< Pair >        //一維list  
  15. #define   SampleSpace std::list< List >        //二維list 用於存放樣本數據  
  16. #define   Child       std::map< int, Node* >   //定義后繼節點集合  
  17. #define   CI          const_iterator  
  18.    
  19. /* 
  20.  *   在ID3算法中,用二維鏈表存放樣本,結構為list< list< pair<int, int> > >,簡記為SampleSpace,取名樣本空間 
  21.  *   樣本數據從根節點開始往下遍歷。每一個節點的定義如下結構體 
  22.  */  
  23.    
  24. struct Node  
  25. {  
  26.     int index;                    //當前節點樣本最大增益對應第index個屬性,根據這個進行分類的  
  27.     int type;                     //當前節點的類型  
  28.     Child next;                   //當前節點的后繼節點集合  
  29.     SampleSpace sample;           //未分類的樣本集合  
  30. };  
  31.    
  32. class ID3{  
  33.    
  34. public:  
  35.    
  36.     ID3(int );      
  37.     ~ID3();  
  38.    
  39.     void PushData(const Type*, const Type);   //將樣本數據Push給二維鏈表  
  40.     void Build();                             //構建決策樹  
  41.     int  Match(const Type*);                  //根據新的樣本預測結果  
  42.     void Print();                             //打印決策樹的節點的值  
  43.    
  44. private:  
  45.    
  46.     void   _clear(Node*);  
  47.     void   _build(Node*, int);  
  48.     int    _match(const int*, Node*);  
  49.     void   _work(Node*);  
  50.     double _entropy(const Map1&, double);  
  51.     int    _get_max_gain(const SampleSpace&);  
  52.     void   _split(Node*, int);  
  53.     void   _get_data(const SampleSpace&, Map1&, Map2&, Map3&);  
  54.     double _info_gain(Map1&, Map2&, doubledouble);  
  55.     int    _same_class(const SampleSpace&);  
  56.     void   _print(Node*);  
  57.    
  58. private:  
  59.    
  60.     int dimension;  
  61.     Node *root;  
  62. };  
  63.    
  64. #endif // _ID3_H_  


ID3.cpp

[cpp]  view plain  copy
 
  在CODE上查看代碼片 派生到我的代碼片
  1. #include <iostream>  
  2. #include <cassert>  
  3. #include <cmath>  
  4.    
  5. #include "ID3.h"  
  6.    
  7. using namespace std;  
  8.    
  9. //初始化ID3的數據成員  
  10. ID3::ID3(int dimension)  
  11. {  
  12.     this->dimension = dimension;  
  13.    
  14.     root = new Node();  
  15.     root->index = -1;  
  16.     root->type = -1;  
  17.     root->next.clear();  
  18.     root->sample.clear();  
  19. }  
  20.    
  21. //清空整個決策樹  
  22. ID3::~ID3()  
  23. {  
  24.     this->dimension = 0;  
  25.     _clear(root);  
  26. }  
  27.    
  28. //x為dimension維的屬性向量,y為向量x對應的值  
  29. void ID3::PushData(const Type *x, const Type y)  
  30. {  
  31.     List single;  
  32.     single.clear();  
  33.     for(int i = 0; i < dimension; i++)  
  34.         single.push_back(make_pair(i + 1, x[i]));  
  35.     single.push_back(make_pair(0, y));  
  36.     root->sample.push_back(single);  
  37. }  
  38.    
  39. void ID3::_clear(Node *node)  
  40. {  
  41.     Child &next = node->next;  
  42.     Child::iterator it;  
  43.     for(it = next.begin(); it != next.end(); it++)  
  44.         _clear(it->second);  
  45.     next.clear();  
  46.     delete node;  
  47. }  
  48.    
  49. void ID3::Build()  
  50. {  
  51.     _build(root, dimension);  
  52. }  
  53.    
  54. void ID3::_build(Node *node, int dimension)  
  55. {  
  56.     //獲取當前節點未分類的樣本數據  
  57.     SampleSpace &sample = node->sample;  
  58.    
  59.     //判斷當前所有樣本是否是同一類,如果不是則返回-1  
  60.     int y = _same_class(sample);  
  61.    
  62.     //如果所有樣本是屬於同一類  
  63.     if(y >= 0)  
  64.     {  
  65.         node->index = -1;  
  66.         node->type = y;  
  67.         return;  
  68.     }  
  69.    
  70.     //在_max_gain()函數中計算出當前節點的最大增益對應的屬性,並根據這個屬性對數據進行划分  
  71.     _work(node);  
  72.    
  73.     //Split完成后清空當前節點的所有數據,以免占用太多內存  
  74.     sample.clear();  
  75.    
  76.     Child &next = node->next;  
  77.     for(Child::iterator it = next.begin(); it != next.end(); it++)  
  78.         _build(it->second, dimension - 1);  
  79. }  
  80.    
  81. //判斷當前所有樣本是否是同一類,如果不是則返回-1  
  82. int ID3::_same_class(const SampleSpace &ss)  
  83. {  
  84.     //取出當前樣本數據的一個Sample  
  85.     const List &f = ss.front();  
  86.    
  87.     //如果沒有x屬性,而只有y,直接返回y  
  88.     if(f.size() == 1)  
  89.         return f.front().second;  
  90.    
  91.     Type y = 0;  
  92.     //取出第一個樣本數據y的結果值  
  93.     for(List::CI it = f.begin(); it != f.end(); it++)  
  94.     {  
  95.         if(!it->first)  
  96.         {  
  97.             y = it->second;  
  98.             break;  
  99.         }  
  100.     }  
  101.    
  102.     //接下來進行判斷,因為list是有序的,所以從前往后遍歷,發現有一對不一樣,則所有樣本不是同一類  
  103.     for(SampleSpace::CI it = ss.begin(); it != ss.end(); it++)  
  104.     {  
  105.         const List &single = *it;  
  106.         for(List::CI i = single.begin(); i != single.end(); i++)  
  107.         {  
  108.             if(!i->first)  
  109.             {  
  110.                 if(y != i->second)  
  111.                     return -1;         //發現不是同一類則返回-1  
  112.                 else  
  113.                     break;  
  114.             }  
  115.         }  
  116.     }  
  117.     return y;     //比較完所有樣本的輸出值y后,發現是同一類,返回y值。  
  118. }  
  119.    
  120. void ID3::_work(Node *node)  
  121. {  
  122.     int mai = _get_max_gain(node->sample);  
  123.     assert(mai >= 0);  
  124.     node->index = mai;  
  125.     _split(node, mai);  
  126. }  
  127.    
  128. //獲取最大的信息增益對應的屬性  
  129. int ID3::_get_max_gain(const SampleSpace &ss)  
  130. {  
  131.     Map1 y;  
  132.     Map2 x;  
  133.     Map3 xy;  
  134.    
  135.     _get_data(ss, y, x, xy);  
  136.     double s = ss.size();  
  137.     double entropy = _entropy(y, s);   //計算熵值  
  138.    
  139.     int mai = -1;  
  140.     double mag = -1;  
  141.    
  142.     for(Map2::iterator it = x.begin(); it != x.end(); it++)  
  143.     {  
  144.         double g = _info_gain(it->second, xy[it->first], s, entropy);    //計算信息增益值  
  145.         if(g > mag)  
  146.         {  
  147.             mag = g;  
  148.             mai = it->first;  
  149.         }  
  150.     }  
  151.    
  152.     if(!x.size() && !xy.size() && y.size())   //如果只有y數據  
  153.         return 0;  
  154.     return mai;  
  155. }  
  156.    
  157. //獲取數據,提取出所有樣本的y值,x[]屬性值,以及屬性值和結果值xy。  
  158. void ID3::_get_data(const SampleSpace &ss, Map1 &y, Map2 &x, Map3 &xy)  
  159. {  
  160.     for(SampleSpace::CI it = ss.begin(); it != ss.end(); it++)  
  161.     {  
  162.     int c = 0;  
  163.         const List &v = *it;  
  164.         for(List::CI p = v.begin(); p != v.end(); p++)  
  165.         {  
  166.             if(!p->first)  
  167.             {  
  168.                 c = p->second;  
  169.                 break;  
  170.             }  
  171.         }  
  172.         ++y[c];  
  173.         for(List::CI p = v.begin(); p != v.end(); p++)  
  174.         {  
  175.             if(p->first)  
  176.             {  
  177.                 ++x[p->first][p->second];  
  178.                 ++xy[p->first][p->second][c];  
  179.             }  
  180.         }  
  181.     }  
  182. }  
  183.    
  184. //計算熵值  
  185. double ID3::_entropy(const Map1 &x, double s)  
  186. {  
  187.     double ans = 0;  
  188.     for(Map1::CI it = x.begin(); it != x.end(); it++)  
  189.     {  
  190.         double t = it->second / s;  
  191.         ans += t * log2(t);  
  192.     }  
  193.     return -ans;  
  194. }  
  195.    
  196. //計算信息增益  
  197. double ID3::_info_gain(Map1 &att_val, Map2 &val_cls, double s, double entropy)  
  198. {  
  199.     double gain = entropy;  
  200.     for(Map1::CI it = att_val.begin(); it != att_val.end(); it++)  
  201.     {  
  202.         double r = it->second / s;  
  203.         double e = _entropy(val_cls[it->first], it->second);  
  204.         gain -= r * e;  
  205.     }  
  206.     return gain;  
  207. }  
  208.    
  209. //對當前節點的sample進行划分  
  210. void ID3::_split(Node *node, int idx)  
  211. {  
  212.     Child &next = node->next;  
  213.     SampleSpace &sample = node->sample;  
  214.    
  215.     for(SampleSpace::iterator it = sample.begin(); it != sample.end(); it++)  
  216.     {  
  217.         List &v = *it;  
  218.         for(List::iterator p = v.begin(); p != v.end(); p++)  
  219.         {  
  220.             if(p->first == idx)  
  221.             {  
  222.                 Node *tmp = next[p->second];  
  223.                 if(!tmp)  
  224.                 {  
  225.                     tmp = new Node();  
  226.                     tmp->index = -1;  
  227.                     tmp->type = -1;  
  228.                     next[p->second] = tmp;  
  229.                 }  
  230.                 v.erase(p);  
  231.                 tmp->sample.push_back(v);  
  232.                 break;  
  233.             }  
  234.         }  
  235.     }  
  236. }  
  237.    
  238. int ID3::Match(const Type *x)  
  239. {  
  240.     return _match(x, root);  
  241. }    
  242.    
  243. int ID3::_match(const Type *v, Node *node)  
  244. {  
  245.     if(node->index < 0)  
  246.         return node->type;  
  247.    
  248.     Child &next = node->next;  
  249.     Child::iterator p = next.find(v[node->index - 1]);  
  250.     if(p == next.end())  
  251.         return -1;  
  252.    
  253.     return _match(v, p->second);  
  254. }  
  255.    
  256. void ID3::Print()  
  257. {  
  258.     _print(root);  
  259. }  
  260.    
  261. void ID3::_print(Node *node)  
  262. {  
  263.     cout << "Index    = " << node->index << endl;  
  264.     cout << "Type     = " << node->type << endl;  
  265.     cout << "NextSize = " << node->next.size() << endl;  
  266.     cout << endl;  
  267.    
  268.     Child &next = node->next;  
  269.     Child::iterator p;  
  270.     for(p = next.begin(); p != next.end(); ++p)  
  271.         _print(p->second);  
  272. }  

main.cpp

[cpp]  view plain  copy
 
  在CODE上查看代碼片 派生到我的代碼片
  1. #include <iostream>  
  2. #include "ID3.h"  
  3.    
  4. using namespace std;  
  5.    
  6. enum outlook {SUNNY, OVERCAST, RAIN };  
  7. enum temp    {HOT,   MILD,     COOL };  
  8. enum hum     {HIGH,  NORMAL         };  
  9. enum windy   {WEAK,  STRONG         };  
  10.    
  11. int samples[14][4] =  
  12. {  
  13.     {SUNNY   ,       HOT ,      HIGH  ,       WEAK  },  
  14.     {SUNNY   ,       HOT ,      HIGH  ,       STRONG},  
  15.     {OVERCAST,       HOT ,      HIGH  ,       WEAK  },  
  16.     {RAIN    ,       MILD,      HIGH  ,       WEAK  },  
  17.     {RAIN    ,       COOL,      NORMAL,       WEAK  },  
  18.     {RAIN    ,       COOL,      NORMAL,       STRONG},  
  19.     {OVERCAST,       COOL,      NORMAL,       STRONG},  
  20.     {SUNNY   ,       MILD,      HIGH  ,       WEAK  },  
  21.     {SUNNY   ,       COOL,      NORMAL,       WEAK  },  
  22.     {RAIN    ,       MILD,      NORMAL,       WEAK  },  
  23.     {SUNNY   ,       MILD,      NORMAL,       STRONG},  
  24.     {OVERCAST,       MILD,      HIGH  ,       STRONG},  
  25.     {OVERCAST,       HOT ,      NORMAL,       WEAK  },  
  26.     {RAIN    ,       MILD,      HIGH  ,       STRONG}  
  27. };  
  28.    
  29. int main()  
  30. {  
  31.     ID3 Tree(4);  
  32.     Tree.PushData((int *)&samples[0], 0);  
  33.     Tree.PushData((int *)&samples[1], 0);  
  34.     Tree.PushData((int *)&samples[2], 1);  
  35.     Tree.PushData((int *)&samples[3], 1);  
  36.     Tree.PushData((int *)&samples[4], 1);  
  37.     Tree.PushData((int *)&samples[5], 0);  
  38.     Tree.PushData((int *)&samples[6], 1);  
  39.     Tree.PushData((int *)&samples[7], 0);  
  40.     Tree.PushData((int *)&samples[8], 1);  
  41.     Tree.PushData((int *)&samples[9], 1);  
  42.     Tree.PushData((int *)&samples[10], 1);  
  43.     Tree.PushData((int *)&samples[11], 1);  
  44.     Tree.PushData((int *)&samples[12], 1);  
  45.     Tree.PushData((int *)&samples[13], 0);  
  46.    
  47.     Tree.Build();  
  48.     Tree.Print();  
  49.     cout << endl;  
  50.     for(int i = 0; i < 14; ++i)  
  51.         cout << "predict value :    " <<Tree.Match( (int *)&samples[i] ) << endl;  
  52.     return 0;  
  53. }  

Makefile

[cpp]  view plain  copy
 
  在CODE上查看代碼片 派生到我的代碼片
  1. Test: main.cpp ID3.h ID3.cpp  
  2.     g++ -o Test ID3.cpp main.cpp  
  3.    
  4. clean:  
  5.     rm Test  

 

 






免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM