ID3算法的核心思想就是以信息增益度量屬性選擇,選擇分裂后信息增益最大的屬性進行分裂。
例子
訓練數據
每一行代表一個數據,前4個元素表示輸入,最后一個是標簽。
traindata=[
5 3 1.6 0.2 1
5 3.4 1.6 0.4 1
5.2 3.5 1.5 0.2 1
5.2 3.4 1.4 0.2 1
4.7 3.2 1.6 0.2 1
4.8 3.1 1.6 0.2 1
5.4 3.4 1.5 0.4 1
5.2 4.1 1.5 0.1 1
5.5 4.2 1.4 0.2 1
4.9 3.1 1.5 0.2 1
5 3.2 1.2 0.2 1
5.5 3.5 1.3 0.2 1
4.9 3.6 1.4 0.1 1
4.4 3 1.3 0.2 1
5.1 3.4 1.5 0.2 1
5 3.5 1.3 0.3 1
4.5 2.3 1.3 0.3 1
4.4 3.2 1.3 0.2 1
5 3.5 1.6 0.6 1
5.1 3.8 1.9 0.4 1
4.8 3 1.4 0.3 1
5.1 3.8 1.6 0.2 1
4.6 3.2 1.4 0.2 1
5.3 3.7 1.5 0.2 1
5 3.3 1.4 0.2 1
6.6 3 4.4 1.4 2
6.8 2.8 4.8 1.4 2
6.7 3 5 1.7 2
6 2.9 4.5 1.5 2
5.7 2.6 3.5 1 2
5.5 2.4 3.8 1.1 2
5.5 2.4 3.7 1 2
5.8 2.7 3.9 1.2 2
6 2.7 5.1 1.6 2
5.4 3 4.5 1.5 2
6 3.4 4.5 1.6 2
6.7 3.1 4.7 1.5 2
6.3 2.3 4.4 1.3 2
5.6 3 4.1 1.3 2
5.5 2.5 4 1.3 2
5.5 2.6 4.4 1.2 2
6.1 3 4.6 1.4 2
5.8 2.6 4 1.2 2
5 2.3 3.3 1 2
5.6 2.7 4.2 1.3 2
5.7 3 4.2 1.2 2
5.7 2.9 4.2 1.3 2
6.2 2.9 4.3 1.3 2
5.1 2.5 3 1.1 2
5.7 2.8 4.1 1.3 2
7.2 3.2 6 1.8 3
6.2 2.8 4.8 1.8 3
6.1 3 4.9 1.8 3
6.4 2.8 5.6 2.1 3
7.2 3 5.8 1.6 3
7.4 2.8 6.1 1.9 3
7.9 3.8 6.4 2 3
6.4 2.8 5.6 2.2 3
6.3 2.8 5.1 1.5 3
6.1 2.6 5.6 1.4 3
7.7 3 6.1 2.3 3
6.3 3.4 5.6 2.4 3
6.4 3.1 5.5 1.8 3
6 3 4.8 1.8 3
6.9 3.1 5.4 2.1 3
6.7 3.1 5.6 2.4 3
6.9 3.1 5.1 2.3 3
5.8 2.7 5.1 1.9 3
6.8 3.2 5.9 2.3 3
6.7 3.3 5.7 2.5 3
6.7 3 5.2 2.3 3
6.3 2.5 5 1.9 3
6.5 3 5.2 2 3
6.2 3.4 5.4 2.3 3
5.9 3 5.1 1.8 3
];
測試數據
testdata=[
5.1 3.5 1.4 0.2 1
4.9 3 1.4 0.2 1
4.7 3.2 1.3 0.2 1
4.6 3.1 1.5 0.2 1
5 3.6 1.4 0.2 1
5.4 3.9 1.7 0.4 1
4.6 3.4 1.4 0.3 1
5 3.4 1.5 0.2 1
4.4 2.9 1.4 0.2 1
4.9 3.1 1.5 0.1 1
5.4 3.7 1.5 0.2 1
4.8 3.4 1.6 0.2 1
4.8 3 1.4 0.1 1
4.3 3 1.1 0.1 1
5.8 4 1.2 0.2 1
5.7 4.4 1.5 0.4 1
5.4 3.9 1.3 0.4 1
5.1 3.5 1.4 0.3 1
5.7 3.8 1.7 0.3 1
5.1 3.8 1.5 0.3 1
5.4 3.4 1.7 0.2 1
5.1 3.7 1.5 0.4 1
4.6 3.6 1 0.2 1
5.1 3.3 1.7 0.5 1
4.8 3.4 1.9 0.2 1
7 3.2 4.7 1.4 2
6.4 3.2 4.5 1.5 2
6.9 3.1 4.9 1.5 2
5.5 2.3 4 1.3 2
6.5 2.8 4.6 1.5 2
5.7 2.8 4.5 1.3 2
6.3 3.3 4.7 1.6 2
4.9 2.4 3.3 1 2
6.6 2.9 4.6 1.3 2
5.2 2.7 3.9 1.4 2
5 2 3.5 1 2
5.9 3 4.2 1.5 2
6 2.2 4 1 2
6.1 2.9 4.7 1.4 2
5.6 2.9 3.6 1.3 2
6.7 3.1 4.4 1.4 2
5.6 3 4.5 1.5 2
5.8 2.7 4.1 1 2
6.2 2.2 4.5 1.5 2
5.6 2.5 3.9 1.1 2
5.9 3.2 4.8 1.8 2
6.1 2.8 4 1.3 2
6.3 2.5 4.9 1.5 2
6.1 2.8 4.7 1.2 2
6.4 2.9 4.3 1.3 2
6.3 3.3 6 2.5 3
5.8 2.7 5.1 1.9 3
7.1 3 5.9 2.1 3
6.3 2.9 5.6 1.8 3
6.5 3 5.8 2.2 3
7.6 3 6.6 2.1 3
4.9 2.5 4.5 1.7 3
7.3 2.9 6.3 1.8 3
6.7 2.5 5.8 1.8 3
7.2 3.6 6.1 2.5 3
6.5 3.2 5.1 2 3
6.4 2.7 5.3 1.9 3
6.8 3 5.5 2.1 3
5.7 2.5 5 2 3
5.8 2.8 5.1 2.4 3
6.4 3.2 5.3 2.3 3
6.5 3 5.5 1.8 3
7.7 3.8 6.7 2.2 3
7.7 2.6 6.9 2.3 3
6 2.2 5 1.5 3
6.9 3.2 5.7 2.3 3
5.6 2.8 4.9 2 3
7.7 2.8 6.7 2 3
6.3 2.7 4.9 1.8 3
6.7 3.3 5.7 2.1 3
];
算法講解
設 \(D\) 為用類別標簽 \(p_i\) 對訓練元組進行的划分,則 \(D\) 的信息熵表示為:
其中 \(p_i\) 表示標簽為某種類別的概率。這個例子中就是1、2、3共三類。
現在假設在屬性A上對D進行划分,划分成了v個子集,則樣本熵為:
相當於是划分的子集按照樣本數量加權求信息熵的和
然后計算在這一屬性上的這一種划分的信息增益:
ID3算法就是在每次需要分裂時,計算每個屬性的某種划分的信息增益,然后選擇信息增益最大的屬性的這種划分來進行分裂。
對於一個相同的訓練集 \(D\),\(H(D)\) 可以看作常數,這時只要 \(H_A(D)\) 最小即可。
如果要 \(H_A(D)\) 最小,那么前面有個負號,所以 \(H_A(D)\) 公式里面的 \(H(D_j)\) 是正的越大越好,再看看 \(H(D)\) 的公式,因為p是不大於1的正數,所以log一定不是正數,而當 \(p\) 趨近於 \(0^+\) 和 \(1\) 的時候,\(plog_2(p)\) 都是趨近於 \(0\) ,所以很明顯,樣本分布越均勻,它的信息熵就越大。
對於連續數據的情況怎么處理
上面的例子中,數據的四個屬性基本上可以看作是連續的了。拿訓練數據的第一個屬性來說,有4點幾、5點幾、6點幾、7點幾,基本上算是連續的了。如果把屬性A等於某個值作為決策樹的條件,那么就會出現問題,可能出現一個數據,它的屬性A有2位小數,這怎么辦?或者說還是1位,但是訓練數據中沒有它。這是數據連續和不連續不同的地方。離散的數據屬性往往是例如顏色={紅色、黃色、藍色......}這樣的。
那么對於連續的數據,先將D中元素按照某一屬性排序,則每兩個不同相鄰元素的中間點(相加除以2)可以看做潛在分裂點,然后按照小於等於和大於分成兩個集合,計算信息增益。然后信息增益最大的那個屬性的那一個分裂點就作為決策樹的分裂點了。
比如訓練數據中的屬性A,最小的是4.4,然后是4.5,所以把4.45當作一個分裂點,計算信息增益,如果時最大的,那么決策樹的這個分裂點就定為了:“屬性A,是否小於等於4.45”,所以最后的決策樹就時一棵二叉樹了。
具體實現
讀入數據
def get_data(file_path):
my_data = list()
with open(file_path, "r") as f:
for line in f.readlines():
if (is_number(line[0])):
line = line.strip('\n')
line = line.split('\t')
new_dict = dict()
new_dict['x'] = [float(x) for x in line[:4]]
new_dict['y'] = int(line[-1])
my_data.append(new_dict)
return my_data
把一行數據看成一個字典,x鍵是一個浮點數組,y鍵是一個整數。然后整個數據是一個列表,一個元素為字典的列表。
信息熵計算
def H(D):
tot = len(D)
p1 = 0
p2 = 0
p3 = 0
res = 0
for i in D:
if i['y'] == 1:
p1 = p1 + 1
if i['y'] == 2:
p2 = p2 + 1
if i['y'] == 3:
p3 = p3 + 1
if (p1 != 0):
p1 = p1 / float(tot)
res -= p1 * np.log(p1)
if (p2 != 0):
p2 = p2 / float(tot)
res -= p2 * np.log(p2)
if (p3 != 0):
p3 = p3 / float(tot)
res -= p3 * np.log(p3)
return res
按照公式 \(H(D)=-\sum_{i=1}^{m}{p_ilog_2(p_i)}\) 計算即可。
選擇信息增益最大的屬性和分裂值
直接暴力枚舉所有的情況,找到最大的信息增益(最小的樣本熵)。代碼如下:
# 按照條件,把集合分成兩部分
def split_data(D, featureID, value):
data1 = list()
data2 = list()
for dic in D:
if (dic['x'][featureID] <= value):
data1.append(dic)
else:
data2.append(dic)
return (data1, data2)
# 選擇信息增益最大的屬性和分裂值
def choose_best_split(D):
# 特征屬性的數量
num_features = 4
tot_D = float(len(D))
min_split_ENT = 999999.0
best_featureID = -1
best_split_value = -1
for i in range(num_features):
# 提取單個屬性
features = [dic['x'][i] for dic in D]
features.sort()
unique(features)
for j in range(len(features) - 1):
split_value = (features[j] + features[j + 1]) / 2.0
data1, data2 = split_data(D, i, split_value)
p1 = len(data1) / tot_D
p2 = len(data2) / tot_D
new_ENT = p1 * H(data1) + p2 * H(data2)
if new_ENT < min_split_ENT:
min_split_ENT = new_ENT
best_featureID = i
best_split_value = split_value
return best_featureID, best_split_value
遞歸創建決策樹
def create_tree(D):
# 遞歸終止條件,當前的數據的標簽都一樣
label_list = [dic['y'] for dic in D]
if (label_list.count(label_list[0]) == len(label_list)):
return label_list[0]
best_featureID, best_split_value = choose_best_split(D)
data1, data2 = split_data(D, best_featureID, best_split_value)
node_name = str(best_featureID + 1) + " <= " + str(best_split_value)
my_tree = {node_name: {}}
my_tree[node_name]['T'] = create_tree(data1)
my_tree[node_name]['F'] = create_tree(data2)
return my_tree
測試決策樹的准確率
def classify(x, tree):
if type(tree).__name__ == 'dict':
node = list(tree.keys())[0]
son = list(tree.values())[0]
else:
return tree
label = -1
key = int(node.split(' <= ')[0]) - 1
value = float(node.split(' <= ')[1])
if x['x'][key] <= value:
if type(son).__name__ == 'dict':
label = classify(x, son['T'])
else:
label = son['T']
else:
if type(son).__name__ == 'dict':
label = classify(x, son['F'])
else:
label = son['F']
return label
def test(D, tree):
tot = float(len(D))
pass_num = 0
for dic in D:
if (classify(dic, tree) == dic['y']):
pass_num = pass_num + 1
return pass_num / tot
測試結果
{
"3 <= 2.45": {
"T": 1,
"F": {
"3 <= 4.75": {
"T": 2,
"F": {
"4 <= 1.75": {
"T": {
"3 <= 5.05": {
"T": 2,
"F": {
"1 <= 6.05": {
"T": 2,
"F": 3
}
}
}
},
"F": 3
}
}
}
}
}
}
96.0%
正確率為96%,具體的決策樹如圖所示:

