Spark機器學習(6):決策樹算法


1. 決策樹基本知識

決策樹就是通過一系列規則對數據進行分類的一種算法,可以分為分類樹和回歸樹兩類,分類樹處理離散變量的,回歸樹是處理連續變量。

樣本一般都有很多個特征,有的特征對分類起很大的作用,有的特征對分類作用很小,甚至沒有作用。如決定是否對一個人貸款是,這個人的信用記錄、收入等就是主要的判斷依據,而性別、婚姻狀況等等就是次要的判斷依據。決策樹構建的過程,就是根據特征的決定性程度,先使用決定性程度高的特征分類,再使用決定性程度低的特征分類,這樣構建出一棵倒立的樹,就是我們需要的決策樹模型,可以用來對數據進行分類。

決策樹學習的過程可以分為三個步驟:1)特征選擇,即從眾多特征中選擇出一個作為當前節點的分類標准;2)決策樹生成,從上到下構建節點;3)剪枝,為了預防和消除過擬合,需要對決策樹剪枝。

2. 決策樹算法

主要的決策樹算法包括ID3、C4.5和CART。

ID3把信息增益作為選擇特征的標准。由於取值較多的特征(如學號)的信息增益比較大,這種算法會偏向於取值較多的特征。而且該算法只能用於離散型的數據,優點是不需要剪枝。

C4.5和ID3比較類似,區別在於使用信息增益比替代信息增益作為選擇特征的標准,因此比ID3更加科學,並且可以用於連續型的數據,但是需要剪枝。

CART(Classification And Regression Tree)采用的是Gini作為選擇的標准。Gini越大,說明不純度越大,這個特征就越不好。

3. MLlib的決策樹算法

MLlib的決策樹算法使用的隨機森林RandomForest的方法,不過並不是真正的隨機森林,因為實際上只有一棵決策樹。

直接上代碼:

import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

/**
  * Created by Administrator on 2017/7/6.
  */
object DecisionTreeTest {

  def main(args: Array[String]): Unit = {

    // 設置運行環境
    val conf = new SparkConf().setAppName("Decision Tree")
      .setMaster("spark://master:7077").setJars(Seq("E:\\Intellij\\Projects\\MachineLearning\\MachineLearning.jar"))
    val sc = new SparkContext(conf)
    Logger.getRootLogger.setLevel(Level.WARN)

    // 讀取樣本數據並解析
    val dataRDD = MLUtils.loadLibSVMFile(sc, "hdfs://master:9000/ml/data/sample_dt_data.txt")
    // 樣本數據划分,訓練樣本占0.8,測試樣本占0.2
    val dataParts = dataRDD.randomSplit(Array(0.8, 0.2))
    val trainRDD = dataParts(0)
    val testRDD = dataParts(1)

    // 決策樹參數
    val numClasses = 5
    val categoricalFeaturesInfo = Map[Int, Int]()
    val impurity = "gini"
    val maxDepth = 5
    val maxBins = 32
    // 建立決策樹模型並訓練
    val model = DecisionTree.trainClassifier(trainRDD, numClasses, categoricalFeaturesInfo,
      impurity, maxDepth, maxBins)

    // 對測試樣本進行測試
    val predictionAndLabel = testRDD.map { point =>
      val score = model.predict(point.features)
      (score, point.label, point.features)
    }
    val showPredict = predictionAndLabel.take(50)
    println("Prediction" + "\t" + "Label" + "\t" + "Data")
    for (i <- 0 to showPredict.length - 1) {
      println(showPredict(i)._1 + "\t" + showPredict(i)._2 + "\t" + showPredict(i)._3)
    }

    // 誤差計算
    val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / testRDD.count()
    println("Accuracy = " + accuracy)

    // 保存模型
    val ModelPath = "hdfs://master:9000/ml/model/Decision_Tree_Model"
    model.save(sc, ModelPath)
    val sameModel = DecisionTreeModel.load(sc, ModelPath)

  }

運行結果:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM