#####################################################################################################
-- coding: utf-8 --
"""
Created on Tue Aug 14 17:36:57 2018
@author: weixw
"""
import numpy as np
定義樹結構,采用的二叉樹,左子樹:條件為true,右子樹:條件為false
leftBranch:左子樹結點
rightBranch:右子樹結點
col:信息增益最大時對應的列索引
value:最優列索引下,划分數據類型的值
results:分類結果
summary:信息增益最大時樣本信息
data:信息增益最大時數據集
class Tree:
def init(self, leftBranch=None, rightBranch=None, col=-1, value=None, results=None, summary=None, data=None):
self.leftBranch = leftBranch
self.rightBranch = rightBranch
self.col = col
self.value = value
self.results = results
self.summary = summary
self.data = data
def __str__(self):
print(u"列號:%d" % self.col)
print(u"列划分值:%s" % self.value)
print(u"樣本信息:%s" % self.summary)
return ""
划分數據集
def splitDataSet(dataSet, value, column):
leftList = []
rightList = []
# 判斷value是否是數值型
if (isinstance(value, int) or isinstance(value, float)):
# 遍歷每一行數據
for rowData in dataSet:
# 如果某一行指定列值>=value,則將該行數據保存在leftList中,否則保存在rightList中
if (rowData[column] >= value):
leftList.append(rowData)
else:
rightList.append(rowData)
# value為標稱型
else:
# 遍歷每一行數據
for rowData in dataSet:
# 如果某一行指定列值==value,則將該行數據保存在leftList中,否則保存在rightList中
if (rowData[column] == value):
leftList.append(rowData)
else:
rightList.append(rowData)
return leftList, rightList
統計標簽類每個樣本個數
'''
該函數是計算gini值的輔助函數,假設輸入的dataSet為為['A', 'B', 'C', 'A', 'A', 'D'],
則輸出為['A':3,' B':1, 'C':1, 'D':1],這樣分類統計dataSet中每個類別的數量
'''
def calculateDiffCount(dataSet):
results = {}
for data in dataSet:
# data[-1] 是數據集最后一列,也就是標簽類
if data[-1] not in results:
results.setdefault(data[-1], 1)
else:
results[data[-1]] += 1
return results
基尼指數公式實現
def gini(dataSet):
# 計算gini的值(Calculate GINI)
# 數據所有行
length = len(dataSet)
# 標簽列合並后的數據集
results = calculateDiffCount(dataSet)
imp = 0.0
for i in results:
imp += results[i] / length * results[i] / length
return 1 - imp
生成決策樹
'''算法步驟'''
'''根據訓練數據集,從根結點開始,遞歸地對每個結點進行以下操作,構建二叉決策樹:
1 設結點的訓練數據集為D,計算現有特征對該數據集的信息增益。此時,對每一個特征A,對其可能取的
每個值a,根據樣本點對A >=a 的測試為“是”或“否”將D分割成D1和D2兩部分,利用基尼指數計算信息增益。
2 在所有可能的特征A以及它們所有可能的切分點a中,選擇信息增益最大的特征及其對應的切分點作為最優特征
與最優切分點,依據最優特征與最優切分點,從現結點生成兩個子結點,將訓練數據集依特征分配到兩個子結點中去。
3 對兩個子結點遞歸地調用1,2,直至滿足停止條件。
4 生成CART決策樹。
'''''''''''''''''''''
evaluationFunc= gini :采用的是基尼指數來衡量信息關注度
def buildDecisionTree(dataSet, evaluationFunc=gini):
# 計算基礎數據集的基尼指數
baseGain = evaluationFunc(dataSet)
# 計算每一行的長度(也就是列總數)
columnLength = len(dataSet[0])
# 計算數據項總數
rowLength = len(dataSet)
# 初始化
bestGain = 0.0 # 信息增益最大值
bestValue = None # 信息增益最大時的列索引,以及划分數據集的樣本值
bestSet = None # 信息增益最大,聽過樣本值划分數據集后的數據子集
# 標簽列除外(最后一列),遍歷每一列數據
for col in range(columnLength - 1):
# 獲取指定列數據
colSet = [example[col] for example in dataSet]
# 獲取指定列樣本唯一值
uniqueColSet = set(colSet)
# 遍歷指定列樣本集
for value in uniqueColSet:
# 分割數據集
leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
# 計算子數據集概率,python3 "/"除號結果為小數
prop = len(leftDataSet) / rowLength
# 計算信息增益
infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
# 找出信息增益最大時的列索引,value,數據子集
if (infoGain > bestGain):
bestGain = infoGain
bestValue = (col, value)
bestSet = (leftDataSet, rightDataSet)
# 結點信息
# nodeDescription = {'impurity:%.3f'%baseGain,'sample:%d'%rowLength}
nodeDescription = {'impurity': '%.3f' % baseGain, 'sample': '%d' % rowLength}
# 數據行標簽類別不一致,可以繼續分類
# 遞歸必須有終止條件
if bestGain > 0:
# 遞歸,生成左子樹結點,右子樹結點
leftBranch = buildDecisionTree(bestSet[0], evaluationFunc)
rightBranch = buildDecisionTree(bestSet[1], evaluationFunc)
return Tree(leftBranch=leftBranch, rightBranch=rightBranch, col=bestValue[0]
, value=bestValue[1], summary=nodeDescription, data=bestSet)
else:
# 數據行標簽類別都相同,分類終止
return Tree(results=calculateDiffCount(dataSet), summary=nodeDescription, data=dataSet)
def createTree(dataSet, evaluationFunc=gini):
# 遞歸建立決策樹, 當gain=0,時停止回歸
# 計算基礎數據集的基尼指數
baseGain = evaluationFunc(dataSet)
# 計算每一行的長度(也就是列總數)
columnLength = len(dataSet[0])
# 計算數據項總數
rowLength = len(dataSet)
# 初始化
bestGain = 0.0 # 信息增益最大值
bestValue = None # 信息增益最大時的列索引,以及划分數據集的樣本值
bestSet = None # 信息增益最大,聽過樣本值划分數據集后的數據子集
# 標簽列除外(最后一列),遍歷每一列數據
for col in range(columnLength - 1):
# 獲取指定列數據
colSet = [example[col] for example in dataSet]
# 獲取指定列樣本唯一值
uniqueColSet = set(colSet)
# 遍歷指定列樣本集
for value in uniqueColSet:
# 分割數據集
leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
# 計算子數據集概率,python3 "/"除號結果為小數
prop = len(leftDataSet) / rowLength
# 計算信息增益
infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
# 找出信息增益最大時的列索引,value,數據子集
if (infoGain > bestGain):
bestGain = infoGain
bestValue = (col, value)
bestSet = (leftDataSet, rightDataSet)
impurity = u'%.3f' % baseGain
sample = '%d' % rowLength
if bestGain > 0:
bestFeatLabel = u'serial:%s\nimpurity:%s\nsample:%s' % (bestValue[0], impurity, sample)
myTree = {bestFeatLabel: {}}
myTree[bestFeatLabel][bestValue[1]] = createTree(bestSet[0], evaluationFunc)
myTree[bestFeatLabel]['no'] = createTree(bestSet[1], evaluationFunc)
return myTree
else: # 遞歸需要返回值
bestFeatValue = u'%s\nimpurity:%s\nsample:%s' % (str(calculateDiffCount(dataSet)), impurity, sample)
return bestFeatValue
分類測試:
'''根據給定測試數據遍歷二叉樹,找到符合條件的葉子結點'''
'''例如測試數據為[5.9,3,4.2,1.75],按照訓練數據生成的決策樹分類的順序為
第2列對應測試數據4.2 =>與決策樹根結點(2)的value(3)比較,>=3則遍歷左子樹,否則遍歷右子樹,
葉子結點就是結果'''
def classify(data, tree):
# 判斷是否是葉子結點,是就返回葉子結點相關信息,否就繼續遍歷
if tree.results != None:
return u"%s\n%s" % (tree.results, tree.summary)
else:
branch = None
v = data[tree.col]
# 數值型數據
if isinstance(v, int) or isinstance(v, float):
if v >= tree.value:
branch = tree.leftBranch
else:
branch = tree.rightBranch
else: # 標稱型數據
if v == tree.value:
branch = tree.leftBranch
else:
branch = tree.rightBranch
return classify(data, branch)
def loadCSV(fileName):
def convertTypes(s):
s = s.strip()
try:
return float(s) if '.' in s else int(s)
except ValueError:
return s
data = np.loadtxt(fileName, dtype='str', delimiter=',')
data = data[1:, :]
dataSet = ([[convertTypes(item) for item in row] for row in data])
return dataSet
多數表決器
列中相同值數量最多為結果
def majorityCnt(classList):
import operator
classCounts = {}
for value in classList:
if (value not in classCounts.keys()):
classCounts[value] = 0
classCounts[value] += 1
sortedClassCount = sorted(classCounts.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
剪枝算法(前序遍歷方式:根=>左子樹=>右子樹)
'''算法步驟
- 從二叉樹的根結點出發,遞歸調用剪枝算法,直至左、右結點都是葉子結點
- 計算父節點(子結點為葉子結點)的信息增益infoGain
- 如果infoGain < miniGain,則選取樣本多的葉子結點來取代父節點
- 循環1,2,3,直至遍歷完整棵樹
'''''''''
def prune(tree, miniGain, evaluationFunc=gini):
print(u"當前結點信息:")
print(str(tree))
# 如果當前結點的左子樹不是葉子結點,遍歷左子樹
if (tree.leftBranch.results == None):
print(u"左子樹結點信息:")
print(str(tree.leftBranch))
prune(tree.leftBranch, miniGain, evaluationFunc)
# 如果當前結點的右子樹不是葉子結點,遍歷右子樹
if (tree.rightBranch.results == None):
print(u"右子樹結點信息:")
print(str(tree.rightBranch))
prune(tree.rightBranch, miniGain, evaluationFunc)
# 左子樹和右子樹都是葉子結點
if (tree.leftBranch.results != None and tree.rightBranch.results != None):
# 計算左葉子結點數據長度
leftLen = len(tree.leftBranch.data)
# 計算右葉子結點數據長度
rightLen = len(tree.rightBranch.data)
# 計算左葉子結點概率
leftProp = leftLen / (leftLen + rightLen)
# 計算該結點的信息增益(子類是葉子結點)
infoGain = (evaluationFunc(tree.leftBranch.data + tree.rightBranch.data) -
leftProp * evaluationFunc(tree.leftBranch.data) - (1 - leftProp) * evaluationFunc(
tree.rightBranch.data))
# 信息增益 < 給定閾值,則說明葉子結點與其父結點特征差別不大,可以剪枝
if (infoGain < miniGain):
# 合並左右葉子結點數據
dataSet = tree.leftBranch.data + tree.rightBranch.data
# 獲取標簽列
classLabels = [example[-1] for example in dataSet]
# 找到樣本最多的標簽值
keyLabel = majorityCnt(classLabels)
# 判斷標簽值是左右葉子結點哪一個
if keyLabel in tree.leftBranch.results:
# 左葉子結點取代父結點
tree.data = tree.leftBranch.data
tree.results = tree.leftBranch.results
tree.summary = tree.leftBranch.summary
else:
# 右葉子結點取代父結點
tree.data = tree.rightBranch.data
tree.results = tree.rightBranch.results
tree.summary = tree.rightBranch.summary
tree.leftBranch = None
tree.rightBranch = None
########################################################################################################
################################################################################################
'''
Created on Oct 14, 2010
@author: Peter Harrington
'''
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="circle", fc="0.7")
arrow_args = dict(arrowstyle="<-")
獲取樹的葉子節點
def getNumLeafs(myTree):
numLeafs = 0
#dict轉化為list
firstSides = list(myTree.keys())
firstStr = firstSides[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
#判斷是否是葉子節點(通過類型判斷,子類不存在,則類型為str;子類存在,則為dict)
if type(secondDict[key]).name=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
return numLeafs
獲取樹的層數
def getTreeDepth(myTree):
maxDepth = 0
#dict轉化為list
firstSides = list(myTree.keys())
firstStr = firstSides[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).name=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree) #this determines the x width of this tree
depth = getTreeDepth(myTree)
firstSides = list(myTree.keys())
firstStr = firstSides[0] #the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).name=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key],cntrPt,str(key)) #recursion
else: #it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
if you do get a dictonary you know it's a tree, and the first element will be another dict
繪制決策樹 樣例1
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
#寬,高間距
plotTree.totalW = float(getNumLeafs(inTree))-3
plotTree.totalD = float(getTreeDepth(inTree))-2
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.95,1.0), '')
plt.show()
繪制決策樹 樣例2
def createPlot1(inTree):
fig = plt.figure(1, facecolor='white')
fig = plt.figure(dpi=255)
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
#寬,高間距
plotTree.totalW = float(getNumLeafs(inTree))-4.5
plotTree.totalD = float(getTreeDepth(inTree)) -3
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (1.0,1.0), '')
plt.show()
繪制樹的根節點和葉子節點(根節點形狀:長方形,葉子節點:橢圓形)
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf()
createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
thisTree = retrieveTree(0)
createPlot(thisTree)
createPlot()
myTree = retrieveTree(0)
numLeafs =getNumLeafs(myTree)
treeDepth =getTreeDepth(myTree)
print(u"葉子節點數目:%d"% numLeafs)
print(u"樹深度:%d"%treeDepth)
##########################################################################################################
###################################################################################################
-- coding: utf-8 --
"""
Created on Wed Aug 15 14:16:59 2018
@author: weixw
"""
import Demo_1.myCart as mc
from Demo_1.myCart import gini
if name == 'main':
import treePlotter as tp
dataSet = mc.loadCSV("F:\C盤移過來的文件\dataSet.csv")
myTree = mc.createTree(dataSet, evaluationFunc=gini)
print(u"myTree:%s"%myTree)
#繪制決策樹
print(u"繪制決策樹:")
tp.createPlot1(myTree)
decisionTree = mc.buildDecisionTree(dataSet, evaluationFunc=gini)
testData = [5.9,3,4.2,1.75]
r = mc.classify(testData, decisionTree)
print(u"分類后測試結果:")
print(r)
print()
mc.prune(decisionTree, 0.4)
r1 = mc.classify(testData, decisionTree)
print(u"剪枝后測試結果:")
print(r1)
################################################################################################################