Contents
1. 決策樹的基本認識
2. ID3算法介紹
3. 信息熵與信息增益
4. ID3算法的C++實現
1. 決策樹的基本認識
決策樹是一種依托決策而建立起來的一種樹。在機器學習中,決策樹是一種預測模型,代表的是一種對
象屬性與對象值之間的一種映射關系,每一個節點代表某個對象,樹中的每一個分叉路徑代表某個可能
的屬性值,而每一個葉子節點則對應從根節點到該葉子節點所經歷的路徑所表示的對象的值。決策樹僅
有單一輸出,如果有多個輸出,可以分別建立獨立的決策樹以處理不同的輸出。接下來講解ID3算法。
2. ID3算法介紹
ID3算法是決策樹的一種,它是基於奧卡姆剃刀原理的,即用盡量用較少的東西做更多的事。ID3算法,
即Iterative Dichotomiser 3,迭代二叉樹3代,是Ross Quinlan發明的一種決策樹算法,這個
算法的基礎就是上面提到的奧卡姆剃刀原理,越是小型的決策樹越優於大的決策樹,盡管如此,也不總
是生成最小的樹型結構,而是一個啟發式算法。
在信息論中,期望信息越小,那么信息增益就越大,從而純度就越高。ID3算法的核心思想就是以信息
增益來度量屬性的選擇,選擇分裂后信息增益最大的屬性進行分裂。該算法采用自頂向下的貪婪搜索遍
歷可能的決策空間。
3. 信息熵與信息增益
在信息增益中,重要性的衡量標准就是看特征能夠為分類系統帶來多少信息,帶來的信息越多,該特征越
重要。在認識信息增益之前,先來看看信息熵的定義
熵這個概念最早起源於物理學,在物理學中是用來度量一個熱力學系統的無序程度,而在信息學里面,熵
是對不確定性的度量。在1948年,香農引入了信息熵,將其定義為離散隨機事件出現的概率,一個系統越
是有序,信息熵就越低,反之一個系統越是混亂,它的信息熵就越高。所以信息熵可以被認為是系統有序
化程度的一個度量。
假如一個隨機變量
的取值為
,每一種取到的概率分別是
,那么
的熵定義為

意思是一個變量的變化情況可能越多,那么它攜帶的信息量就越大。
對於分類系統來說,類別
是變量,它的取值是
,而每一個類別出現的概率分別是

而這里的
就是類別的總數,此時分類系統的熵就可以表示為

以上就是信息熵的定義,接下來介紹信息增益。
信息增益是針對一個一個特征而言的,就是看一個特征
,系統有它和沒有它時的信息量各是多少,兩者
的差值就是這個特征給系統帶來的信息量,即信息增益。
接下來以天氣預報的例子來說明。下面是描述天氣數據表,學習目標是play或者not play。

可以看出,一共14個樣例,包括9個正例和5個負例。那么當前信息的熵計算如下

在決策樹分類問題中,信息增益就是決策樹在進行屬性選擇划分前和划分后信息的差值。假設利用
屬性Outlook來分類,那么如下圖

划分后,數據被分為三部分了,那么各個分支的信息熵計算如下

那么划分后的信息熵為

代表在特征屬性
的條件下樣本的條件熵。那么最終得到特征屬性
帶來的信息增益為

信息增益的計算公式如下

其中
為全部樣本集合,
是屬性
所有取值的集合,
是
的其中一個屬性值,
是
中屬性
的
值為
的樣例集合,
為
中所含樣例數。
在決策樹的每一個非葉子結點划分之前,先計算每一個屬性所帶來的信息增益,選擇最大信息增益的屬性來划
分,因為信息增益越大,區分樣本的能力就越強,越具有代表性,很顯然這是一種自頂向下的貪心策略。以上
就是ID3算法的核心思想。
3.決策樹停止的條件
如果發生以下的情況,決策樹將停止分割
1.改群數據的每一筆數據已經歸類到每一類數據中,即數據已經不能繼續在分。
2.該群數據已經找不到新的屬性進行節點分割
3.該群數據沒有任何未處理的數據
4. ID3算法的C++實現
接下來開始用C++實現ID3算法,包括以下文件

ID3.h
- #ifndef _ID3_H_
- #define _ID3_H_
-
- #include <utility>
- #include <list>
- #include <map>
-
- #define Type int //樣本數據類型
-
- #define Map1 std::map< int, Type > //定義一維map
- #define Map2 std::map< int, Map1 > //定義二維map
- #define Map3 std::map< int, Map2 > //定義三維map
- #define Pair std::pair<int, Type>
- #define List std::list< Pair > //一維list
- #define SampleSpace std::list< List > //二維list 用於存放樣本數據
- #define Child std::map< int, Node* > //定義后繼節點集合
- #define CI const_iterator
-
-
-
-
-
-
- struct Node
- {
- int index;
- int type;
- Child next;
- SampleSpace sample;
- };
-
- class ID3{
-
- public:
-
- ID3(int );
- ~ID3();
-
- void PushData(const Type*, const Type);
- void Build();
- int Match(const Type*);
- void Print();
-
- private:
-
- void _clear(Node*);
- void _build(Node*, int);
- int _match(const int*, Node*);
- void _work(Node*);
- double _entropy(const Map1&, double);
- int _get_max_gain(const SampleSpace&);
- void _split(Node*, int);
- void _get_data(const SampleSpace&, Map1&, Map2&, Map3&);
- double _info_gain(Map1&, Map2&, double, double);
- int _same_class(const SampleSpace&);
- void _print(Node*);
-
- private:
-
- int dimension;
- Node *root;
- };
-
- #endif // _ID3_H_
ID3.cpp
- #include <iostream>
- #include <cassert>
- #include <cmath>
-
- #include "ID3.h"
-
- using namespace std;
-
-
- ID3::ID3(int dimension)
- {
- this->dimension = dimension;
-
- root = new Node();
- root->index = -1;
- root->type = -1;
- root->next.clear();
- root->sample.clear();
- }
-
-
- ID3::~ID3()
- {
- this->dimension = 0;
- _clear(root);
- }
-
-
- void ID3::PushData(const Type *x, const Type y)
- {
- List single;
- single.clear();
- for(int i = 0; i < dimension; i++)
- single.push_back(make_pair(i + 1, x[i]));
- single.push_back(make_pair(0, y));
- root->sample.push_back(single);
- }
-
- void ID3::_clear(Node *node)
- {
- Child &next = node->next;
- Child::iterator it;
- for(it = next.begin(); it != next.end(); it++)
- _clear(it->second);
- next.clear();
- delete node;
- }
-
- void ID3::Build()
- {
- _build(root, dimension);
- }
-
- void ID3::_build(Node *node, int dimension)
- {
-
- SampleSpace &sample = node->sample;
-
-
- int y = _same_class(sample);
-
-
- if(y >= 0)
- {
- node->index = -1;
- node->type = y;
- return;
- }
-
-
- _work(node);
-
-
- sample.clear();
-
- Child &next = node->next;
- for(Child::iterator it = next.begin(); it != next.end(); it++)
- _build(it->second, dimension - 1);
- }
-
-
- int ID3::_same_class(const SampleSpace &ss)
- {
-
- const List &f = ss.front();
-
-
- if(f.size() == 1)
- return f.front().second;
-
- Type y = 0;
-
- for(List::CI it = f.begin(); it != f.end(); it++)
- {
- if(!it->first)
- {
- y = it->second;
- break;
- }
- }
-
-
- for(SampleSpace::CI it = ss.begin(); it != ss.end(); it++)
- {
- const List &single = *it;
- for(List::CI i = single.begin(); i != single.end(); i++)
- {
- if(!i->first)
- {
- if(y != i->second)
- return -1;
- else
- break;
- }
- }
- }
- return y;
- }
-
- void ID3::_work(Node *node)
- {
- int mai = _get_max_gain(node->sample);
- assert(mai >= 0);
- node->index = mai;
- _split(node, mai);
- }
-
-
- int ID3::_get_max_gain(const SampleSpace &ss)
- {
- Map1 y;
- Map2 x;
- Map3 xy;
-
- _get_data(ss, y, x, xy);
- double s = ss.size();
- double entropy = _entropy(y, s);
-
- int mai = -1;
- double mag = -1;
-
- for(Map2::iterator it = x.begin(); it != x.end(); it++)
- {
- double g = _info_gain(it->second, xy[it->first], s, entropy);
- if(g > mag)
- {
- mag = g;
- mai = it->first;
- }
- }
-
- if(!x.size() && !xy.size() && y.size())
- return 0;
- return mai;
- }
-
-
- void ID3::_get_data(const SampleSpace &ss, Map1 &y, Map2 &x, Map3 &xy)
- {
- for(SampleSpace::CI it = ss.begin(); it != ss.end(); it++)
- {
- int c = 0;
- const List &v = *it;
- for(List::CI p = v.begin(); p != v.end(); p++)
- {
- if(!p->first)
- {
- c = p->second;
- break;
- }
- }
- ++y[c];
- for(List::CI p = v.begin(); p != v.end(); p++)
- {
- if(p->first)
- {
- ++x[p->first][p->second];
- ++xy[p->first][p->second][c];
- }
- }
- }
- }
-
-
- double ID3::_entropy(const Map1 &x, double s)
- {
- double ans = 0;
- for(Map1::CI it = x.begin(); it != x.end(); it++)
- {
- double t = it->second / s;
- ans += t * log2(t);
- }
- return -ans;
- }
-
-
- double ID3::_info_gain(Map1 &att_val, Map2 &val_cls, double s, double entropy)
- {
- double gain = entropy;
- for(Map1::CI it = att_val.begin(); it != att_val.end(); it++)
- {
- double r = it->second / s;
- double e = _entropy(val_cls[it->first], it->second);
- gain -= r * e;
- }
- return gain;
- }
-
-
- void ID3::_split(Node *node, int idx)
- {
- Child &next = node->next;
- SampleSpace &sample = node->sample;
-
- for(SampleSpace::iterator it = sample.begin(); it != sample.end(); it++)
- {
- List &v = *it;
- for(List::iterator p = v.begin(); p != v.end(); p++)
- {
- if(p->first == idx)
- {
- Node *tmp = next[p->second];
- if(!tmp)
- {
- tmp = new Node();
- tmp->index = -1;
- tmp->type = -1;
- next[p->second] = tmp;
- }
- v.erase(p);
- tmp->sample.push_back(v);
- break;
- }
- }
- }
- }
-
- int ID3::Match(const Type *x)
- {
- return _match(x, root);
- }
-
- int ID3::_match(const Type *v, Node *node)
- {
- if(node->index < 0)
- return node->type;
-
- Child &next = node->next;
- Child::iterator p = next.find(v[node->index - 1]);
- if(p == next.end())
- return -1;
-
- return _match(v, p->second);
- }
-
- void ID3::Print()
- {
- _print(root);
- }
-
- void ID3::_print(Node *node)
- {
- cout << "Index = " << node->index << endl;
- cout << "Type = " << node->type << endl;
- cout << "NextSize = " << node->next.size() << endl;
- cout << endl;
-
- Child &next = node->next;
- Child::iterator p;
- for(p = next.begin(); p != next.end(); ++p)
- _print(p->second);
- }
main.cpp
- #include <iostream>
- #include "ID3.h"
-
- using namespace std;
-
- enum outlook {SUNNY, OVERCAST, RAIN };
- enum temp {HOT, MILD, COOL };
- enum hum {HIGH, NORMAL };
- enum windy {WEAK, STRONG };
-
- int samples[14][4] =
- {
- {SUNNY , HOT , HIGH , WEAK },
- {SUNNY , HOT , HIGH , STRONG},
- {OVERCAST, HOT , HIGH , WEAK },
- {RAIN , MILD, HIGH , WEAK },
- {RAIN , COOL, NORMAL, WEAK },
- {RAIN , COOL, NORMAL, STRONG},
- {OVERCAST, COOL, NORMAL, STRONG},
- {SUNNY , MILD, HIGH , WEAK },
- {SUNNY , COOL, NORMAL, WEAK },
- {RAIN , MILD, NORMAL, WEAK },
- {SUNNY , MILD, NORMAL, STRONG},
- {OVERCAST, MILD, HIGH , STRONG},
- {OVERCAST, HOT , NORMAL, WEAK },
- {RAIN , MILD, HIGH , STRONG}
- };
-
- int main()
- {
- ID3 Tree(4);
- Tree.PushData((int *)&samples[0], 0);
- Tree.PushData((int *)&samples[1], 0);
- Tree.PushData((int *)&samples[2], 1);
- Tree.PushData((int *)&samples[3], 1);
- Tree.PushData((int *)&samples[4], 1);
- Tree.PushData((int *)&samples[5], 0);
- Tree.PushData((int *)&samples[6], 1);
- Tree.PushData((int *)&samples[7], 0);
- Tree.PushData((int *)&samples[8], 1);
- Tree.PushData((int *)&samples[9], 1);
- Tree.PushData((int *)&samples[10], 1);
- Tree.PushData((int *)&samples[11], 1);
- Tree.PushData((int *)&samples[12], 1);
- Tree.PushData((int *)&samples[13], 0);
-
- Tree.Build();
- Tree.Print();
- cout << endl;
- for(int i = 0; i < 14; ++i)
- cout << "predict value : " <<Tree.Match( (int *)&samples[i] ) << endl;
- return 0;
- }
Makefile
- Test: main.cpp ID3.h ID3.cpp
- g++ -o Test ID3.cpp main.cpp
-
- clean:
- rm Test