Spark ML 之 如何將海量字符串映射為數字——StringIndexer/IndexToString


一、StringIndexer

在使用Spark MLlib協同過濾ALS API的時候發現Rating的三個參數:用戶id,商品名稱,商品打分,前兩個都需要是Int值。那么問題來了,當你的用戶id,商品名稱是String類型的情況下,我們必須尋找一個方法可以將海量String映射為數字類型。好在Spark MLlib可以answer這一切。

StringIndexer 將一列字符串標簽編碼成一列下標標簽,下標范圍在[0, 標簽數量),順序是標簽的出現頻率。所以最經常出現的標簽獲得的下標就是0。如果輸入列是數字的,我們會將其轉換成字符串,然后將字符串改為下標。當下游管道組成部分,比如說Estimator 或Transformer 使用將字符串轉換成下標的標簽時,你必須將組成部分的輸入列設置為這個將字符串轉換成下標后的列名。很多情況下,你可以使用setInputCol設置輸入列。

val spark = SparkSession.builder().appName("db").master("local[*]").getOrCreate()
val df = spark.createDataFrame(
      Seq((0,"a"),(1,"b"),(2,"c"),(3,"a"),(4,"a"),(5,"c"))
    ).toDF("id","category")
val indexer =new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
    val indexed = indexer.fit(df) // 訓練一個StringIndexer => StringIndexerModel
                  .transform(df)  // 用 StringIndexerModel transfer 數據集

此外,當你針對一個數據集訓練了一個StringIndexer,然后使用其去transform另一個數據集的時候,針對不可見的標簽StringIndexer 有兩個應對策略:

  •  throw an exception (which is the default)默認是拋出異常
  •  skip the row containing the unseen label entirely跳過包含不可見標簽的這一行
val df2 = spark.createDataFrame(
      Seq((0,"a"),(1,"b"),(2,"c"),(3,"d"),(4,"e"),(5,"f"))
    ).toDF("id","category")
val indexed2 = indexer.fit(df)
         .setHandleInvalid("skip") // 不匹配就跳過 .transform(df2) // 用 不匹配的stringIndexModel 來 transfer 數據集 indexed2.show()

 二、IndexToString

IndexToString 和StringIndexer是對稱的,它將一列下標標簽映射回一列包含原始字符串的標簽。常用的場合是使用StringIndexer生產下標,通過這些下標訓練模型,通過IndexToString從預測出的下標列重新獲得原始標簽。不過,你也可以使用你自己的標簽。

val converter =new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")
    val converted = converter.transform(indexed) // class IndexToString extends Transformer
    converted.select("id","originalCategory")
      .show()

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM