Spark DecisionTreeClassifier: Decision Tree Classification


1. Overview

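Decision trees and their ensembles are popular methods for classification and regression in machine learning. Decision trees are widely used because they are easy to interpret, handle categorical features, extend to multiclass classification, require no feature scaling, and can capture non-linearities and feature interactions. Tree ensemble algorithms such as random forests and boosting are among the top performers for classification and regression tasks.
The spark.ml implementation supports decision trees for binary and multiclass classification and for regression, using both continuous and categorical features. It partitions data by rows, which allows distributed training on millions or even billions of instances.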

 

2. Inputs and Outputs

All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.
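As a minimal sketch of this rule, the snippet below turns off two of the classifier's output columns by setting their Params to empty strings. It assumes spark-mllib is on the classpath; the object name ExcludeOutputColumns and the helper predictionOnly are made up for illustration.

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Hypothetical object name, used only for this illustration.
object ExcludeOutputColumns {
  // Build a classifier configured to emit only the "prediction" column.
  def predictionOnly(): DecisionTreeClassifier =
    new DecisionTreeClassifier()
      .setRawPredictionCol("") // skip the per-class count vector
      .setProbabilityCol("")   // skip the normalized probability vector

  def main(args: Array[String]): Unit = {
    val dt = predictionOnly()
    println(s"rawPredictionCol='${dt.getRawPredictionCol}', probabilityCol='${dt.getProbabilityCol}'")
  }
}
```

The predictionCol Param is left at its default, so the fitted model would still write the "prediction" column.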

Input Columns

Param name  | Type(s) | Default    | Description
labelCol    | Double  | "label"    | Label to predict
featuresCol | Vector  | "features" | Feature vector

Output Columns

Param name       | Type(s) | Default         | Description | Notes
predictionCol    | Double  | "prediction"    | Predicted label |
rawPredictionCol | Vector  | "rawPrediction" | Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction | Classification only
probabilityCol   | Vector  | "probability"   | Vector of length # classes equal to rawPrediction normalized to a multinomial distribution | Classification only
varianceCol      | Double  |                 | The biased sample variance of prediction | Regression only
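To make the relationship between rawPrediction and probability concrete, here is a small plain-Scala sketch (no Spark required). It only illustrates the description in the table; it is not Spark's internal code. The object name, the helper names, and the handling of an all-zero count vector are assumptions made for this example.

```scala
// Hypothetical helper, for illustration only.
object RawPredictionToProbability {
  // Normalize the per-class label counts at a leaf node into class probabilities.
  // Assumption: an all-zero count vector is mapped to all zeros.
  def normalize(counts: Array[Double]): Array[Double] = {
    val total = counts.sum
    if (total == 0.0) counts.map(_ => 0.0) else counts.map(_ / total)
  }

  // The predicted class is the index holding the largest count.
  def predict(counts: Array[Double]): Int = counts.indexOf(counts.max)

  def main(args: Array[String]): Unit = {
    // A leaf that saw 0 instances of class 0, 3 of class 1, and 1 of class 2.
    val counts = Array(0.0, 3.0, 1.0)
    println(normalize(counts).mkString("[", ", ", "]")) // [0.0, 0.75, 0.25]
    println(predict(counts))                            // 1
  }
}
```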

3. Code

package com.home.spark.ml

import org.apache.spark.SparkConf
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, RegressionEvaluator}
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.regression.DecisionTreeRegressor
import org.apache.spark.sql.{Dataset, Row, SparkSession}

object Ex_DecisionTree {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf(true).setMaster("local[2]").setAppName("spark ml")
    val spark = SparkSession.builder().config(conf).getOrCreate()

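    // Converting an RDD to a DataFrame or Dataset requires the implicit conversions of a SparkSession instance.
    // Import them here; note that `spark` is the name of the SparkSession value, not a package.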
    import spark.implicits._

    val data = spark.sparkContext.textFile("input/iris.data.txt")
      .map(_.split(","))
      .map(a => Iris(
        Vectors.dense(a(0).toDouble, a(1).toDouble, a(2).toDouble, a(3).toDouble),
        a(4))
      ).toDF()

    data.createOrReplaceTempView("iris")
    val df = spark.sql("select * from iris")
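    // Print the first 10 rows as "label : features"; take(10) avoids collecting the whole dataset to the driver.
    df.map(r => s"${r(1)} : ${r(0)}").take(10).foreach(println)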


    // Index the label column and the feature column
    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df)
    val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures")
      .setMaxCategories(4).fit(df)


    // Decision tree classifier
    val dtClassifier = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")

    // Convert the predicted class indices back to the original string labels
    val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictionLabel").setLabels(labelIndexer.labels)

    // Split the dataset into two parts: 70% for training and 30% for testing
    val Array(trainingData, testData): Array[Dataset[Row]] = df.randomSplit(Array(0.7,0.3))

    // Build the pipeline
    val pipeline = new Pipeline().setStages(Array(labelIndexer,featureIndexer,dtClassifier,labelConverter))

    // Fit the pipeline on the training data
    val modelDecisionTreeClassifier = pipeline.fit(trainingData)

    // Predict on the test data
    val result = modelDecisionTreeClassifier.transform(testData)

    result.show(150,false)

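    /**
      * Samples are divided into positive and negative samples.
      * TP: number of positive samples correctly classified as positive.
      * TN: number of negative samples correctly classified as negative.
      * FP: number of negative samples incorrectly classified as positive (actually negative, predicted positive).
      * FN: number of positive samples incorrectly classified as negative (actually positive, predicted negative).
      *
      * Accuracy (ACC)
      * Total samples = TP + TN + FP + FN
      * ACC = (TP + TN) / (total samples)
      * This metric is mainly suited to datasets with balanced classes.
      */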
    val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy: Double = evaluator.evaluate(result)

    println("Accuracy = " + accuracy)

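    /**
      * Precision
      * Precision = TP / (TP + FP): the fraction of samples predicted as positive that are actually positive.
      * "weightedPrecision" averages the per-class precision, weighted by each class's share of the true labels.
      */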
    val evaluator2 = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
      .setMetricName("weightedPrecision")
    val weightedPrecision: Double = evaluator2.evaluate(result)

    println("weightedPrecision = " + weightedPrecision)

    /**
      * Recall
      * Recall = TP / (TP + FN): the fraction of all actually positive samples that the model correctly predicts as positive.
      */
    val evaluator3 = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
      .setMetricName("weightedRecall")
    val weightedRecall: Double = evaluator3.evaluate(result)

    println("weightedRecall = " + weightedRecall)


    val treeModel = modelDecisionTreeClassifier.stages(2).asInstanceOf[DecisionTreeClassificationModel]
    println("Learned classification tree model:\n" + treeModel.toDebugString)

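    // Decision tree regressor. Note: it is fit on the class index "indexedLabel" as if it were a
    // continuous target, purely to demonstrate the API; the RMSE below has no natural meaning for class labels.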
    val dtRegressor = new DecisionTreeRegressor().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")

    val pipelineRegressor = new Pipeline()
      .setStages(Array(labelIndexer,featureIndexer,dtRegressor,labelConverter))

    val modelRegressor = pipelineRegressor.fit(trainingData)
    val result2 = modelRegressor.transform(testData)

    result2.show(150,false)

    // Evaluate
    val regressionEvaluator = new RegressionEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
      .setMetricName("rmse")
    val rmse = regressionEvaluator.evaluate(result2)
    println("rmse = " + rmse)
    spark.stop()
  }
}

case class Iris(features: Vector, label: String)
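
The evaluation metrics described in the comments of the listing can be checked by hand on a tiny example. The plain-Scala sketch below (no Spark required) computes accuracy, weighted precision, and weighted recall from (true label, predicted label) pairs, weighting each class by its share of the true labels. The object name, method names, and sample data are made up for illustration, and returning 0.0 for a class with no predictions is an assumption of this sketch.

```scala
// Hypothetical helper, for illustration only. Each pair is (trueLabel, predictedLabel).
object MulticlassMetricsSketch {
  type Pairs = Seq[(Int, Int)]

  // Fraction of samples whose prediction matches the true label.
  def accuracy(pairs: Pairs): Double =
    pairs.count { case (y, p) => y == p }.toDouble / pairs.size

  // Precision of class c: TP / (TP + FP), treating c as the positive class.
  def precision(pairs: Pairs, c: Int): Double = {
    val tp = pairs.count { case (y, p) => y == c && p == c }
    val fp = pairs.count { case (y, p) => y != c && p == c }
    if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
  }

  // Recall of class c: TP / (TP + FN), treating c as the positive class.
  def recall(pairs: Pairs, c: Int): Double = {
    val tp = pairs.count { case (y, p) => y == c && p == c }
    val fn = pairs.count { case (y, p) => y == c && p != c }
    if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)
  }

  // Average a per-class metric, weighting each class by its share of the true labels.
  def weighted(pairs: Pairs, metric: (Pairs, Int) => Double): Double = {
    val n = pairs.size.toDouble
    pairs.map(_._1).distinct.map { c =>
      val weight = pairs.count(_._1 == c) / n
      weight * metric(pairs, c)
    }.sum
  }

  // Three classes with two samples each; one class-1 sample is mispredicted as class 2.
  val sample: Pairs = Seq((0, 0), (0, 0), (1, 1), (1, 2), (2, 2), (2, 2))

  def main(args: Array[String]): Unit = {
    println(f"accuracy          = ${accuracy(sample)}%.4f")            // 5/6
    println(f"weightedPrecision = ${weighted(sample, precision)}%.4f") // 8/9
    println(f"weightedRecall    = ${weighted(sample, recall)}%.4f")    // 5/6
  }
}
```

With this weighting, weighted recall always equals accuracy, because each class's recall is multiplied by that class's share of the samples, so the sum reduces to the total number of correct predictions divided by the total number of samples.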

 

