It supports both continuous and categorical variables. A categorical variable is an attribute with, say, three values a, b, c; it needs to be handled with VectorIndexer from the Feature Transformers package first.
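For reference, a minimal VectorIndexer sketch (1.6-era API; the column names and the maxCategories value are placeholder assumptions):

import org.apache.spark.ml.feature.VectorIndexer

val indexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4) // features with <= 4 distinct values are treated as categorical
val indexedData = indexer.fit(data).transform(data)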
The class opens with a pile of parameter setters (a usage sketch putting them all together follows the quoted source comment below):
setMaxDepth: maximum tree depth
setMaxBins: maximum number of bins, used to discretize continuous variables for approximate statistics, e.g. a variable with 100 distinct values gets split into 10 segments
setMinInstancesPerNode: minimum number of instances per node
setMinInfoGain: minimum information gain
setMaxMemoryInMB: maximum memory in MB; the larger this is, the more node splits get processed in one pass
setCacheNodeIds: whether to cache node IDs; caching can speed up training of deeper trees
setCheckpointInterval: checkpoint interval, i.e. how many iterations between persisting state
setImpurity: three impurity measures exist: entropy, gini and variance; for regression it is necessarily variance
setSubsamplingRate: the sampling rate, i.e. the fraction of the data sampled to build each new tree
setSeed: the sampling seed; the same seed reproduces the same samples
setNumTrees: the number of trees in the forest
setFeatureSubsetStrategy: the feature subset selection strategy. Random forest adds two kinds of randomness: the samples used to build each tree are random, and the features considered at each split node are random; beyond that it differs little from a single decision tree. The source comment reads:
* The number of features to consider for splits at each tree node.
* Supported options:
* - "auto": Choose automatically for task://默認策略
* If numTrees == 1, set to "all." //決策樹選擇所有屬性
* If numTrees > 1 (forest), set to "sqrt" for classification and //決策森林 分類選擇屬性數開平方,回歸選擇三分之一屬性
* to "onethird" for regression.
* - "all": use all features
* - "onethird": use 1/3 of the features
* - "sqrt": use sqrt(number of features)
* - "log2": use log2(number of features) //還有取對數的
* (default = "auto")
*
* These various settings are based on the following references:
* - log2: tested in Breiman (2001)
* - sqrt: recommended by Breiman manual for random forests
* - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
* package.
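Putting the setters above together, here is a hedged usage sketch (the column names and values are arbitrary placeholders, not recommendations):

import org.apache.spark.ml.regression.RandomForestRegressor

val rf = new RandomForestRegressor()
  .setFeaturesCol("indexedFeatures")
  .setLabelCol("label")
  .setMaxDepth(5)
  .setMaxBins(32)
  .setMinInstancesPerNode(1)
  .setMinInfoGain(0.0)
  .setImpurity("variance") // regression only supports variance
  .setSubsamplingRate(1.0)
  .setSeed(42L)
  .setNumTrees(20)
  .setFeatureSubsetStrategy("auto") // resolves to "onethird" for a regression forest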
That's it for the parameters. The next important piece is this code:
val categoricalFeatures: Map[Int, Int] =
  MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
The mildly annoying bit here is dataset.schema($(featuresCol)):
/** An alias for [[getOrDefault()]]. */
protected final def $[T](param: Param[T]): T = getOrDefault(param)
This shows that $(featuresCol) just resolves to a column name; in practice you can call data.schema("features") directly. What data.schema("features") returns is a StructField:
case class StructField(name: String, dataType: DataType, nullable: Boolean = true, metadata: Metadata = Metadata.empty) extends Product with Serializable
A StructField carries these four fields; they are worth knowing, since ML code uses them everywhere.
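A quick way to poke at those four fields, assuming a DataFrame named data with a "features" column:

val field = data.schema("features") // a StructField
println(field.name)     // "features"
println(field.dataType) // the column's type, e.g. the ML vector UDT for feature columns
println(field.nullable) // true by default
println(field.metadata) // empty unless a transformer such as VectorIndexer filled it in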
Back to getCategoricalFeatures: this method identifies whether a feature is binary or nominal (for example {a, b} is binary, {a, b, c} is nominal) and finally puts each categorical feature's index, together with its number of distinct values, into a Map.
Its job overlaps with VectorIndexer, but in practice VectorIndexer is what you actually use: most real-world data is read from SQL, and data read from SQL comes with empty metadata, so getCategoricalFeatures cannot tell binary from nominal features there.
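To make that concrete, a made-up example of what the resulting map might look like:

// hypothetical output: feature 0 is binary (2 values), feature 3 is nominal (3 values);
// continuous features simply do not appear in the map
val categoricalFeatures: Map[Int, Int] = Map(0 -> 2, 3 -> 3)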
Then comes:
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
  super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
val trees =
  RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
    .map(_.asInstanceOf[DecisionTreeRegressionModel])
val numFeatures = oldDataset.first().features.size
new RandomForestRegressionModel(trees, numFeatures)
As you can see, this still delegates to the old RDD-based API. The run method is the core, over 1000 lines; we'll trace through it in detail later. What comes back is a RandomForestRegressionModel holding an Array[DecisionTreeRegressionModel], the set of trained decision trees that makes up the forest, plus the number of features. Let's keep going into RandomForestRegressionModel.
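Before diving in, a hedged sketch of how those two pieces surface to user code (assuming the 1.6 API and the rf / indexedData names from the earlier sketches):

val model = rf.fit(indexedData)
println(model.numFeatures)                     // the numFeatures passed to the constructor
model.trees.foreach(t => println(t.numNodes))  // the underlying decision tree models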
In version 1.6 every tree's weight is 1.0. The model also contains this method:
override protected def transformImpl(dataset: DataFrame): DataFrame = {
  val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
  val predictUDF = udf { (features: Any) =>
    bcastModel.value.predict(features.asInstanceOf[Vector])
  }
  dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
You can see the model is broadcast to the executors, a prediction UDF is built around it, and withColumn appends the prediction column to the original data.
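The same udf + withColumn pattern works in user code too; here is a sketch on a toy function, since predict itself is protected (data and "label" are assumed names):

import org.apache.spark.sql.functions.{col, udf}

val doubleUDF = udf { (x: Double) => x * 2 }
val withDoubled = data.withColumn("doubled", doubleUDF(col("label"))) // appends a new column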
Note the call dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))): the first argument is the new column's name, the second is a computed Column (col(...) returns a Column). The predict method is as follows:
override protected def predict(features: Vector): Double = {
  // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
  // Predict average of tree predictions.
  // Ignore the weights since all are 1.0 for now.
  _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
}
So prediction starts from each tree's root node: predictImpl(features) walks the tree recursively and returns the leaf node the sample is routed to (how that recursion works will also be covered in detail later); .prediction then reads each leaf's value, and the per-tree predictions are summed and divided by the number of trees.
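The averaging itself is trivial; a toy check with made-up per-tree predictions:

// three hypothetical trees predicting 1.0, 2.0 and 4.0 for the same sample
val treePredictions = Seq(1.0, 2.0, 4.0)
val forestPrediction = treePredictions.sum / treePredictions.size // (1 + 2 + 4) / 3 ≈ 2.33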
Further down there is a method that computes per-feature importances:
lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
Its implementation:
private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
  val totalImportances = new OpenHashMap[Int, Double]()
  trees.foreach { tree =>
    // Aggregate feature importance vector for this tree (i.e. first compute each tree's own importances)
    val importances = new OpenHashMap[Int, Double]()
    computeFeatureImportance(tree.rootNode, importances)
    // Normalize importance vector for this tree, and add it to total.
    // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
    val treeNorm = importances.map(_._2).sum
    if (treeNorm != 0) {
      importances.foreach { case (idx, impt) =>
        val normImpt = impt / treeNorm
        totalImportances.changeValue(idx, normImpt, _ + normImpt)
      }
    }
  }
  // Normalize importances
  normalizeMapValues(totalImportances)
  // Construct vector
  val d = if (numFeatures != -1) {
    numFeatures
  } else {
    // Find max feature index used in trees
    val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
    maxFeatureIndex + 1
  }
  if (d == 0) {
    assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
      s" importance: No splits in forest, but some non-zero importances.")
  }
  val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
  Vectors.sparse(d, indices.toArray, values.toArray)
}
computeFeatureImportance is implemented as follows:
/**
* Recursive method for computing feature importances for one tree.
* This walks down the tree, adding to the importance of 1 feature at each node.
* @param node Current node in recursion
* @param importances Aggregate feature importances, modified by this method
*/
private[impl] def computeFeatureImportance(
    node: Node,
    importances: OpenHashMap[Int, Double]): Unit = {
  node match {
    case n: InternalNode =>
      val feature = n.split.featureIndex
      val scaledGain = n.gain * n.impurityStats.count
      importances.changeValue(feature, scaledGain, _ + scaledGain)
      computeFeatureImportance(n.leftChild, importances)
      computeFeatureImportance(n.rightChild, importances)
    case n: LeafNode =>
      // do nothing
  }
}
Feature importance is an extension of the gain concept, so gain is used here to explain how the importance is computed.
Here you can also see why each tree's computation starts from rootNode: the method walks the tree recursively, level by level. On a LeafNode it does nothing; on a split node (InternalNode in the code) it first takes the index of the feature the node splits on, multiplies the node's gain by the number of instances reaching the node, adds that scaled gain to the feature's running importance, and then recurses into the left and right children until the whole tree is covered.
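A toy re-implementation of that per-node accumulation on a plain mutable Map (the numbers are made up):

import scala.collection.mutable

val importances = mutable.Map.empty[Int, Double].withDefaultValue(0.0)
// an internal node splitting on feature 3 with gain 0.2 over 50 instances adds 0.2 * 50 = 10.0
importances(3) += 0.2 * 50
// a deeper node splitting on the same feature with gain 0.1 over 20 instances adds 2.0 more
importances(3) += 0.1 * 20
println(importances(3)) // 12.0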
Back in featureImportances, val treeNorm = importances.map(_._2).sum sums the importances just computed for one tree, and each feature's importance is then divided by that tree total, i.e. turned into its share of the tree's overall importance. Once the foreach finishes, every tree's normalized importances have been accumulated into totalImportances. normalizeMapValues(totalImportances) then repeats the same normalization across trees, so the final importances sum to 1. With the feature count and the importances sorted by feature index in hand, everything is packed into a sparse vector, which is the final output.
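To see the two normalizations with concrete numbers, a toy sketch on plain Maps (the raw importances are made up):

// tree 1 raw importances: feature 0 -> 2.0, feature 2 -> 6.0 (treeNorm = 8.0)
// tree 2 raw importances: feature 0 -> 1.0, feature 1 -> 1.0 (treeNorm = 2.0)
val tree1 = Map(0 -> 2.0 / 8.0, 2 -> 6.0 / 8.0) // 0.25, 0.75
val tree2 = Map(0 -> 1.0 / 2.0, 1 -> 1.0 / 2.0) // 0.5, 0.5
val total = (tree1.toSeq ++ tree2.toSeq)
  .groupBy(_._1)
  .map { case (idx, pairs) => idx -> pairs.map(_._2).sum } // 0 -> 0.75, 1 -> 0.5, 2 -> 0.75
val norm = total.values.sum // 2.0
val finalImportances = total.map { case (idx, v) => idx -> v / norm } // now sums to 1.0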
That's all for the entry method.
That still leaves two things to cover: the 1000+ line run method, and how nodes are recursively assigned. Both will be covered in later posts.