ID3算法的核心思想就是以信息增益度量属性选择,选择分裂后信息增益最大的属性进行分裂。
例子
训练数据
每一行代表一个数据,前4个元素表示输入,最后一个是标签。
traindata=[
5 3 1.6 0.2 1
5 3.4 1.6 0.4 1
5.2 3.5 1.5 0.2 1
5.2 3.4 1.4 0.2 1
4.7 3.2 1.6 0.2 1
4.8 3.1 1.6 0.2 1
5.4 3.4 1.5 0.4 1
5.2 4.1 1.5 0.1 1
5.5 4.2 1.4 0.2 1
4.9 3.1 1.5 0.2 1
5 3.2 1.2 0.2 1
5.5 3.5 1.3 0.2 1
4.9 3.6 1.4 0.1 1
4.4 3 1.3 0.2 1
5.1 3.4 1.5 0.2 1
5 3.5 1.3 0.3 1
4.5 2.3 1.3 0.3 1
4.4 3.2 1.3 0.2 1
5 3.5 1.6 0.6 1
5.1 3.8 1.9 0.4 1
4.8 3 1.4 0.3 1
5.1 3.8 1.6 0.2 1
4.6 3.2 1.4 0.2 1
5.3 3.7 1.5 0.2 1
5 3.3 1.4 0.2 1
6.6 3 4.4 1.4 2
6.8 2.8 4.8 1.4 2
6.7 3 5 1.7 2
6 2.9 4.5 1.5 2
5.7 2.6 3.5 1 2
5.5 2.4 3.8 1.1 2
5.5 2.4 3.7 1 2
5.8 2.7 3.9 1.2 2
6 2.7 5.1 1.6 2
5.4 3 4.5 1.5 2
6 3.4 4.5 1.6 2
6.7 3.1 4.7 1.5 2
6.3 2.3 4.4 1.3 2
5.6 3 4.1 1.3 2
5.5 2.5 4 1.3 2
5.5 2.6 4.4 1.2 2
6.1 3 4.6 1.4 2
5.8 2.6 4 1.2 2
5 2.3 3.3 1 2
5.6 2.7 4.2 1.3 2
5.7 3 4.2 1.2 2
5.7 2.9 4.2 1.3 2
6.2 2.9 4.3 1.3 2
5.1 2.5 3 1.1 2
5.7 2.8 4.1 1.3 2
7.2 3.2 6 1.8 3
6.2 2.8 4.8 1.8 3
6.1 3 4.9 1.8 3
6.4 2.8 5.6 2.1 3
7.2 3 5.8 1.6 3
7.4 2.8 6.1 1.9 3
7.9 3.8 6.4 2 3
6.4 2.8 5.6 2.2 3
6.3 2.8 5.1 1.5 3
6.1 2.6 5.6 1.4 3
7.7 3 6.1 2.3 3
6.3 3.4 5.6 2.4 3
6.4 3.1 5.5 1.8 3
6 3 4.8 1.8 3
6.9 3.1 5.4 2.1 3
6.7 3.1 5.6 2.4 3
6.9 3.1 5.1 2.3 3
5.8 2.7 5.1 1.9 3
6.8 3.2 5.9 2.3 3
6.7 3.3 5.7 2.5 3
6.7 3 5.2 2.3 3
6.3 2.5 5 1.9 3
6.5 3 5.2 2 3
6.2 3.4 5.4 2.3 3
5.9 3 5.1 1.8 3
];
测试数据
testdata=[
5.1 3.5 1.4 0.2 1
4.9 3 1.4 0.2 1
4.7 3.2 1.3 0.2 1
4.6 3.1 1.5 0.2 1
5 3.6 1.4 0.2 1
5.4 3.9 1.7 0.4 1
4.6 3.4 1.4 0.3 1
5 3.4 1.5 0.2 1
4.4 2.9 1.4 0.2 1
4.9 3.1 1.5 0.1 1
5.4 3.7 1.5 0.2 1
4.8 3.4 1.6 0.2 1
4.8 3 1.4 0.1 1
4.3 3 1.1 0.1 1
5.8 4 1.2 0.2 1
5.7 4.4 1.5 0.4 1
5.4 3.9 1.3 0.4 1
5.1 3.5 1.4 0.3 1
5.7 3.8 1.7 0.3 1
5.1 3.8 1.5 0.3 1
5.4 3.4 1.7 0.2 1
5.1 3.7 1.5 0.4 1
4.6 3.6 1 0.2 1
5.1 3.3 1.7 0.5 1
4.8 3.4 1.9 0.2 1
7 3.2 4.7 1.4 2
6.4 3.2 4.5 1.5 2
6.9 3.1 4.9 1.5 2
5.5 2.3 4 1.3 2
6.5 2.8 4.6 1.5 2
5.7 2.8 4.5 1.3 2
6.3 3.3 4.7 1.6 2
4.9 2.4 3.3 1 2
6.6 2.9 4.6 1.3 2
5.2 2.7 3.9 1.4 2
5 2 3.5 1 2
5.9 3 4.2 1.5 2
6 2.2 4 1 2
6.1 2.9 4.7 1.4 2
5.6 2.9 3.6 1.3 2
6.7 3.1 4.4 1.4 2
5.6 3 4.5 1.5 2
5.8 2.7 4.1 1 2
6.2 2.2 4.5 1.5 2
5.6 2.5 3.9 1.1 2
5.9 3.2 4.8 1.8 2
6.1 2.8 4 1.3 2
6.3 2.5 4.9 1.5 2
6.1 2.8 4.7 1.2 2
6.4 2.9 4.3 1.3 2
6.3 3.3 6 2.5 3
5.8 2.7 5.1 1.9 3
7.1 3 5.9 2.1 3
6.3 2.9 5.6 1.8 3
6.5 3 5.8 2.2 3
7.6 3 6.6 2.1 3
4.9 2.5 4.5 1.7 3
7.3 2.9 6.3 1.8 3
6.7 2.5 5.8 1.8 3
7.2 3.6 6.1 2.5 3
6.5 3.2 5.1 2 3
6.4 2.7 5.3 1.9 3
6.8 3 5.5 2.1 3
5.7 2.5 5 2 3
5.8 2.8 5.1 2.4 3
6.4 3.2 5.3 2.3 3
6.5 3 5.5 1.8 3
7.7 3.8 6.7 2.2 3
7.7 2.6 6.9 2.3 3
6 2.2 5 1.5 3
6.9 3.2 5.7 2.3 3
5.6 2.8 4.9 2 3
7.7 2.8 6.7 2 3
6.3 2.7 4.9 1.8 3
6.7 3.3 5.7 2.1 3
];
算法讲解
设 \(D\) 为用类别标签 \(p_i\) 对训练元组进行的划分,则 \(D\) 的信息熵表示为:
其中 \(p_i\) 表示标签为某种类别的概率。这个例子中就是1、2、3共三类。
现在假设在属性A上对D进行划分,划分成了v个子集,则样本熵为:
相当于是划分的子集按照样本数量加权求信息熵的和
然后计算在这一属性上的这一种划分的信息增益:
ID3算法就是在每次需要分裂时,计算每个属性的某种划分的信息增益,然后选择信息增益最大的属性的这种划分来进行分裂。
对于一个相同的训练集 \(D\),\(H(D)\) 可以看作常数,这时只要 \(H_A(D)\) 最小即可。
如果要 \(H_A(D)\) 最小,那么前面有个负号,所以 \(H_A(D)\) 公式里面的 \(H(D_j)\) 是正的越大越好,再看看 \(H(D)\) 的公式,因为p是不大于1的正数,所以log一定不是正数,而当 \(p\) 趋近于 \(0^+\) 和 \(1\) 的时候,\(plog_2(p)\) 都是趋近于 \(0\) ,所以很明显,样本分布越均匀,它的信息熵就越大。
对于连续数据的情况怎么处理
上面的例子中,数据的四个属性基本上可以看作是连续的了。拿训练数据的第一个属性来说,有4点几、5点几、6点几、7点几,基本上算是连续的了。如果把属性A等于某个值作为决策树的条件,那么就会出现问题,可能出现一个数据,它的属性A有2位小数,这怎么办?或者说还是1位,但是训练数据中没有它。这是数据连续和不连续不同的地方。离散的数据属性往往是例如颜色={红色、黄色、蓝色......}这样的。
那么对于连续的数据,先将D中元素按照某一属性排序,则每两个不同相邻元素的中间点(相加除以2)可以看做潜在分裂点,然后按照小于等于和大于分成两个集合,计算信息增益。然后信息增益最大的那个属性的那一个分裂点就作为决策树的分裂点了。
比如训练数据中的属性A,最小的是4.4,然后是4.5,所以把4.45当作一个分裂点,计算信息增益,如果时最大的,那么决策树的这个分裂点就定为了:“属性A,是否小于等于4.45”,所以最后的决策树就时一棵二叉树了。
具体实现
读入数据
def get_data(file_path):
my_data = list()
with open(file_path, "r") as f:
for line in f.readlines():
if (is_number(line[0])):
line = line.strip('\n')
line = line.split('\t')
new_dict = dict()
new_dict['x'] = [float(x) for x in line[:4]]
new_dict['y'] = int(line[-1])
my_data.append(new_dict)
return my_data
把一行数据看成一个字典,x键是一个浮点数组,y键是一个整数。然后整个数据是一个列表,一个元素为字典的列表。
信息熵计算
def H(D):
tot = len(D)
p1 = 0
p2 = 0
p3 = 0
res = 0
for i in D:
if i['y'] == 1:
p1 = p1 + 1
if i['y'] == 2:
p2 = p2 + 1
if i['y'] == 3:
p3 = p3 + 1
if (p1 != 0):
p1 = p1 / float(tot)
res -= p1 * np.log(p1)
if (p2 != 0):
p2 = p2 / float(tot)
res -= p2 * np.log(p2)
if (p3 != 0):
p3 = p3 / float(tot)
res -= p3 * np.log(p3)
return res
按照公式 \(H(D)=-\sum_{i=1}^{m}{p_ilog_2(p_i)}\) 计算即可。
选择信息增益最大的属性和分裂值
直接暴力枚举所有的情况,找到最大的信息增益(最小的样本熵)。代码如下:
# 按照条件,把集合分成两部分
def split_data(D, featureID, value):
data1 = list()
data2 = list()
for dic in D:
if (dic['x'][featureID] <= value):
data1.append(dic)
else:
data2.append(dic)
return (data1, data2)
# 选择信息增益最大的属性和分裂值
def choose_best_split(D):
# 特征属性的数量
num_features = 4
tot_D = float(len(D))
min_split_ENT = 999999.0
best_featureID = -1
best_split_value = -1
for i in range(num_features):
# 提取单个属性
features = [dic['x'][i] for dic in D]
features.sort()
unique(features)
for j in range(len(features) - 1):
split_value = (features[j] + features[j + 1]) / 2.0
data1, data2 = split_data(D, i, split_value)
p1 = len(data1) / tot_D
p2 = len(data2) / tot_D
new_ENT = p1 * H(data1) + p2 * H(data2)
if new_ENT < min_split_ENT:
min_split_ENT = new_ENT
best_featureID = i
best_split_value = split_value
return best_featureID, best_split_value
递归创建决策树
def create_tree(D):
# 递归终止条件,当前的数据的标签都一样
label_list = [dic['y'] for dic in D]
if (label_list.count(label_list[0]) == len(label_list)):
return label_list[0]
best_featureID, best_split_value = choose_best_split(D)
data1, data2 = split_data(D, best_featureID, best_split_value)
node_name = str(best_featureID + 1) + " <= " + str(best_split_value)
my_tree = {node_name: {}}
my_tree[node_name]['T'] = create_tree(data1)
my_tree[node_name]['F'] = create_tree(data2)
return my_tree
测试决策树的准确率
def classify(x, tree):
if type(tree).__name__ == 'dict':
node = list(tree.keys())[0]
son = list(tree.values())[0]
else:
return tree
label = -1
key = int(node.split(' <= ')[0]) - 1
value = float(node.split(' <= ')[1])
if x['x'][key] <= value:
if type(son).__name__ == 'dict':
label = classify(x, son['T'])
else:
label = son['T']
else:
if type(son).__name__ == 'dict':
label = classify(x, son['F'])
else:
label = son['F']
return label
def test(D, tree):
tot = float(len(D))
pass_num = 0
for dic in D:
if (classify(dic, tree) == dic['y']):
pass_num = pass_num + 1
return pass_num / tot
测试结果
{
"3 <= 2.45": {
"T": 1,
"F": {
"3 <= 4.75": {
"T": 2,
"F": {
"4 <= 1.75": {
"T": {
"3 <= 5.05": {
"T": 2,
"F": {
"1 <= 6.05": {
"T": 2,
"F": 3
}
}
}
},
"F": 3
}
}
}
}
}
}
96.0%
正确率为96%,具体的决策树如图所示: