最近在用Spark MLlib進行特征處理時,對於StringIndexer和IndexToString遇到了點問題,查閱官方文檔也沒有解決疑惑。無奈之下翻看源碼才明白其中一二...這就給大家娓娓道來。
更多內容參考我的大數據學習之路
文檔說明
StringIndexer 字符串轉索引
StringIndexer可以把字符串的列按照出現頻率進行排序,出現次數最高的對應的Index為0。比如下面的列表進行StringIndexer
id | category |
---|---|
0 | a |
1 | b |
2 | c |
3 | a |
4 | a |
5 | c |
就可以得到如下:
id | category | categoryIndex |
---|---|---|
0 | a | 0.0 |
1 | b | 2.0 |
2 | c | 1.0 |
3 | a | 0.0 |
4 | a | 0.0 |
5 | c | 1.0 |
可以看到出現次數最多的"a",索引為0;次數最少的"b"索引為2。
針對訓練集中沒有出現的字符串值,spark提供了幾種處理的方法:
- error,直接拋出異常
- skip,跳過該樣本數據
- keep,使用一個新的最大索引,來表示所有未出現的值
下面是基於Spark MLlib 2.2.0的代碼樣例:
package xingoo.ml.features.tranformer
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.StringIndexer
object StringIndexerTest {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("string-indexer").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
val df = spark.createDataFrame(
Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")
val df1 = spark.createDataFrame(
Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "e"), (5, "f"))
).toDF("id", "category")
val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.setHandleInvalid("keep") //skip keep error
val model = indexer.fit(df)
val indexed = model.transform(df1)
indexed.show(false)
}
}
得到的結果為:
+---+--------+-------------+
|id |category|categoryIndex|
+---+--------+-------------+
|0 |a |0.0 |
|1 |b |2.0 |
|2 |c |1.0 |
|3 |a |0.0 |
|4 |e |3.0 |
|5 |f |3.0 |
+---+--------+-------------+
IndexToString 索引轉字符串
這個索引轉回字符串要搭配前面的StringIndexer一起使用才行:
package xingoo.ml.features.tranformer
import org.apache.spark.ml.attribute.Attribute
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession
object IndexToString2 {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
val df = spark.createDataFrame(Seq(
(0, "a"),
(1, "b"),
(2, "c"),
(3, "a"),
(4, "a"),
(5, "c")
)).toDF("id", "category")
val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.fit(df)
val indexed = indexer.transform(df)
println(s"Transformed string column '${indexer.getInputCol}' " +
s"to indexed column '${indexer.getOutputCol}'")
indexed.show()
val inputColSchema = indexed.schema(indexer.getOutputCol)
println(s"StringIndexer will store labels in output column metadata: " +
s"${Attribute.fromStructField(inputColSchema).toString}\n")
val converter = new IndexToString()
.setInputCol("categoryIndex")
.setOutputCol("originalCategory")
val converted = converter.transform(indexed)
println(s"Transformed indexed column '${converter.getInputCol}' back to original string " +
s"column '${converter.getOutputCol}' using labels in metadata")
converted.select("id", "categoryIndex", "originalCategory").show()
}
}
得到的結果如下:
Transformed string column 'category' to indexed column 'categoryIndex'
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
| 0| a| 0.0|
| 1| b| 2.0|
| 2| c| 1.0|
| 3| a| 0.0|
| 4| a| 0.0|
| 5| c| 1.0|
+---+--------+-------------+
StringIndexer will store labels in output column metadata: {"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}
Transformed indexed column 'categoryIndex' back to original string column 'originalCategory' using labels in metadata
+---+-------------+----------------+
| id|categoryIndex|originalCategory|
+---+-------------+----------------+
| 0| 0.0| a|
| 1| 2.0| b|
| 2| 1.0| c|
| 3| 0.0| a|
| 4| 0.0| a|
| 5| 1.0| c|
+---+-------------+----------------+
使用問題
假如處理的過程很復雜,重新生成了一個DataFrame,此時想要把這個DataFrame基於IndexToString轉回原來的字符串怎么辦呢? 先來試試看:
package xingoo.ml.features.tranformer
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
import org.apache.spark.sql.SparkSession
object IndexToString3 {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master("local[*]").appName("dct").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
val df = spark.createDataFrame(Seq(
(0, "a"),
(1, "b"),
(2, "c"),
(3, "a"),
(4, "a"),
(5, "c")
)).toDF("id", "category")
val df2 = spark.createDataFrame(Seq(
(0, 2.0),
(1, 1.0),
(2, 1.0),
(3, 0.0)
)).toDF("id", "index")
val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.fit(df)
val indexed = indexer.transform(df)
val converter = new IndexToString()
.setInputCol("categoryIndex")
.setOutputCol("originalCategory")
val converted = converter.transform(df2)
converted.show()
}
}
運行后發現異常:
18/07/05 20:20:32 INFO StateStoreCoordinatorRef: Registered StateStoreCoordinator endpoint
Exception in thread "main" java.lang.IllegalArgumentException: Field "categoryIndex" does not exist.
at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
at scala.collection.AbstractMap.getOrElse(Map.scala:59)
at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)
at org.apache.spark.ml.feature.IndexToString.transformSchema(StringIndexer.scala:338)
at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:74)
at org.apache.spark.ml.feature.IndexToString.transform(StringIndexer.scala:352)
at xingoo.ml.features.tranformer.IndexToString3$.main(IndexToString3.scala:37)
at xingoo.ml.features.tranformer.IndexToString3.main(IndexToString3.scala)
這是為什么呢?跟隨源碼來看吧!
源碼剖析
首先我們創建一個DataFrame,獲得原始數據:
val df = spark.createDataFrame(Seq(
(0, "a"),
(1, "b"),
(2, "c"),
(3, "a"),
(4, "a"),
(5, "c")
)).toDF("id", "category")
然后創建對應的StringIndexer:
val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.setHandleInvalid("skip")
.fit(df)
這里面的fit就是在訓練轉換器了,進入fit():
override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true)
// 這里針對需要轉換的列先強制轉換成字符串,然后遍歷統計每個字符串出現的次數
val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
.rdd
.map(_.getString(0))
.countByValue()
// counts是一個map,里面的內容為{a->3, b->1, c->2}
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
// 按照個數大小排序,返回數組,[a, c, b]
// 把這個label保存起來,並返回對應的model(mllib里邊的模型都是這個套路,跟sklearn學的)
copyValues(new StringIndexerModel(uid, labels).setParent(this))
}
這樣就得到了一個列表,列表里面的內容是[a, c, b],然后執行transform來進行轉換:
val indexed = indexer.transform(df)
這個transform可想而知就是用這個數組對每一行的該列進行轉換,但是它其實還做了其他的事情:
override def transform(dataset: Dataset[_]): DataFrame = {
...
// --------
// 通過label生成一個Metadata,這個很關鍵!!!
// metadata其實是一個map,內容為:
// {"ml_attr":{"vals":["a","c","b"],"type":"nominal","name":"categoryIndex"}}
// --------
val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// 如果是skip則過濾一些數據
...
// 下面是針對不同的情況處理轉換的列,邏輯很簡單
val indexer = udf { label: String =>
...
if (labelToIndex.contains(label)) {
labelToIndex(label) //如果正常,就進行轉換
} else if (keepInvalid) {
labels.length // 如果是keep,就返回索引的最大值(即數組的長度)
} else {
... // 如果是error,就拋出異常
}
}
// 保留之前所有的列,新增一個字段,並設置字段的StructField中的Metadata!!!!
// 並設置字段的StructField中的Metadata!!!!
// 並設置字段的StructField中的Metadata!!!!
// 並設置字段的StructField中的Metadata!!!!
filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
}
看到了嗎!關鍵的地方在這里,給新增加的字段的類型StructField設置了一個Metadata。這個Metadata正常都是空的{}
,但是這里設置了metadata之后,里面包含了label數組的信息。
接下來看看IndexToString是怎么用的,由於IndexToString是一個Transformer,因此只有一個trasform方法:
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val inputColSchema = dataset.schema($(inputCol))
// If the labels array is empty use column metadata
// 關鍵是這里:
// 如果IndexToString設置了labels數組,就直接返回;
// 否則,就讀取了傳入的DataFrame的StructField中的Metadata
val values = if (!isDefined(labels) || $(labels).isEmpty) {
Attribute.fromStructField(inputColSchema)
.asInstanceOf[NominalAttribute].values.get
} else {
$(labels)
}
// 基於這個values把index轉成對應的值
val indexer = udf { index: Double =>
val idx = index.toInt
if (0 <= idx && idx < values.length) {
values(idx)
} else {
throw new SparkException(s"Unseen index: $index ??")
}
}
val outputColName = $(outputCol)
dataset.select(col("*"),
indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
}
了解StringIndexer和IndexToString的原理機制后,就可以作出如下的應對策略了。
1 增加StructField的MetaData信息
val df2 = spark.createDataFrame(Seq(
(0, 2.0),
(1, 1.0),
(2, 1.0),
(3, 0.0)
)).toDF("id", "index").select(col("*"),col("index").as("formated_index", indexed.schema("categoryIndex").metadata))
val converter = new IndexToString()
.setInputCol("formated_index")
.setOutputCol("origin_col")
val converted = converter.transform(df2)
converted.show(false)
+---+-----+--------------+----------+
|id |index|formated_index|origin_col|
+---+-----+--------------+----------+
|0 |2.0 |2.0 |b |
|1 |1.0 |1.0 |c |
|2 |1.0 |1.0 |c |
|3 |0.0 |0.0 |a |
+---+-----+--------------+----------+
2 獲取之前StringIndexer后的DataFrame中的Label信息
val df3 = spark.createDataFrame(Seq(
(0, 2.0),
(1, 1.0),
(2, 1.0),
(3, 0.0)
)).toDF("id", "index")
val converter2 = new IndexToString()
.setInputCol("index")
.setOutputCol("origin_col")
.setLabels(indexed.schema("categoryIndex").metadata.getMetadata("ml_attr").getStringArray("vals"))
val converted2 = converter2.transform(df3)
converted2.show(false)
+---+-----+----------+
|id |index|origin_col|
+---+-----+----------+
|0 |2.0 |b |
|1 |1.0 |c |
|2 |1.0 |c |
|3 |0.0 |a |
+---+-----+----------+
兩種方法都能得到正確的輸出。
完整的代碼可以參考github鏈接:
最終還是推薦詳細閱讀官方文檔,不過官方文檔真心有些粗糙,想要了解其中的原理,還是得靜下心來看看源碼。