Implementing a Decision Tree in Python


Reference: Machine Learning in Action

1. Basic Idea

A decision tree of the kind we are familiar with might look like the following:

[Figure: an example decision tree]

The goal of the decision tree algorithm is to produce a classifier shaped like the tree in the figure above, so the key step of the algorithm is how to choose the feature tested at each node.

The guiding principle for splitting a dataset is to make disordered data more ordered. Many methods can split a dataset, each with its own strengths and weaknesses. The measure of the information in a set is called Shannon entropy.
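
For a set whose class labels occur with proportions p1, ..., pk, the Shannon entropy is H = -(p1·log2(p1) + ... + pk·log2(pk)). It is 0 for a pure set (a single class) and maximal, log2(k), when all k classes are equally likely; the calcShannonEnt function below computes exactly this quantity.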

The pseudocode is as follows:

check whether every item in the dataset belongs to the same class;
    if so return the class label;
    else
        find the best feature for splitting the dataset
        split the dataset
        create a branch node
            for each subset produced by the split
                call createBranch and add the result to the branch node
        return the branch node

To compare candidate splits, the algorithm uses information gain: the reduction in Shannon entropy achieved by partitioning the dataset on a feature.
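
For a feature that splits dataset D into subsets D1, ..., Dm, the information gain is Gain = H(D) - Σ (|Di| / |D|) · H(Di), where H is the Shannon entropy; the chooseBestFeatureToSplit function below computes this for every feature and keeps the one with the largest gain.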

2. Code

# -*- coding:utf8 -*-
import operator
from math import log

# Compute the Shannon entropy of a dataset (the class label is the last element of each row)
def calcShannonEnt(dataSet):
	numEntries = len(dataSet)
	labelCounts = {}
	for featVec in dataSet:
		currentLabel = featVec[-1]
		if currentLabel not in labelCounts:
			labelCounts[currentLabel] = 0
		labelCounts[currentLabel] += 1

	shannonEnt = 0.0
	for key in labelCounts:
		prob = float(labelCounts[key])/numEntries
		shannonEnt -= prob*log(prob, 2)
	return shannonEnt
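
# Quick sanity check (an illustrative sketch, not part of the book's
# listing). In the toy "fish" dataset from Machine Learning in Action,
# two of the five class labels are 'yes' and three are 'no':
#   myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'],
#            [0, 1, 'no'], [0, 1, 'no']]
#   calcShannonEnt(myDat)  ->  0.9709505944546686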

# Split the dataset: return the rows where feature `axis` equals `value`, with that column removed
def splitDataSet(dataSet, axis, value):
	retDataSet = []
	for featVec in dataSet:
		if featVec[axis] == value:
			reducedFeatVec = featVec[:axis]
			reducedFeatVec.extend(featVec[axis+1:])
			retDataSet.append(reducedFeatVec)

	return retDataSet
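
# Example (continuing the toy dataset above): splitting on feature 0
# with value 1 keeps the three rows whose first column is 1 and
# removes that column:
#   splitDataSet(myDat, 0, 1)  ->  [[1, 'yes'], [1, 'yes'], [0, 'no']]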

# Choose the best feature to split on: the one with the largest information gain
def chooseBestFeatureToSplit(dataSet):
	numFeatures = len(dataSet[0]) - 1
	baseEntropy = calcShannonEnt(dataSet)
	bestInfoGain = 0.0
	bestFeature = -1
	for i in range(numFeatures):
		featList = [example[i] for example in dataSet]
		uniqueVals = set(featList)
		newEntropy = 0.0
		for value in uniqueVals:
			subDataSet = splitDataSet(dataSet, i, value)
			prob = len(subDataSet)/float(len(dataSet))
			newEntropy += prob * calcShannonEnt(subDataSet)
		infoGain = baseEntropy - newEntropy
		if (infoGain > bestInfoGain):
			bestInfoGain = infoGain
			bestFeature = i
	return bestFeature
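
# For the toy dataset above, feature 0 yields an information gain of
# about 0.420 bits versus about 0.171 bits for feature 1, so
# chooseBestFeatureToSplit(myDat) returns 0.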


# Majority vote: return the class label that occurs most often
def majorityCnt(classList):
	classCount = {}
	for vote in classList:
		if vote not in classCount:
			classCount[vote] = 0
		classCount[vote] += 1
	# sort once, after counting is complete, by descending count
	sortedClassCount = sorted(classCount.items(), \
		key=lambda item: item[1], reverse=True)

	return sortedClassCount[0][0]
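
# An equivalent, more idiomatic alternative (a sketch using only the
# standard library, not part of the original listing):
#   from collections import Counter
#   def majorityCnt(classList):
#       return Counter(classList).most_common(1)[0][0]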

# Recursively build the decision tree as nested dicts:
# {feature label: {feature value: subtree or class label}}
def createTree(dataSet, labels):
	classList = [example[-1] for example in dataSet]
	if classList.count(classList[0]) == len(classList):
		return classList[0]  # all examples share one class: return a leaf
	if len(dataSet[0]) == 1:
		return majorityCnt(classList)  # no features left: majority vote
	bestFeat = chooseBestFeatureToSplit(dataSet)
	bestFeatLabel = labels[bestFeat]
	myTree = {bestFeatLabel:{}}
	del labels[bestFeat]  # note: this mutates the caller's labels list
	featValues = [example[bestFeat] for example in dataSet]
	uniqueVals = set(featValues)
	for value in uniqueVals:
		subLabels = labels[:]
		myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)

	return myTree
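
To try the code end to end, the snippet below (an illustrative sketch, not part of the book's listing) builds the small "fish identification" dataset used in Machine Learning in Action and prints the resulting tree. Since createTree deletes entries from the labels list it is given, a copy is passed in.

if __name__ == '__main__':
	myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'],
		[0, 1, 'no'], [0, 1, 'no']]
	labels = ['no surfacing', 'flippers']
	print(calcShannonEnt(myDat))            # 0.9709505944546686
	print(chooseBestFeatureToSplit(myDat))  # 0
	tree = createTree(myDat, labels[:])     # pass a copy: createTree mutates labels
	print(tree)
	# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}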

