train.csv數據:
id,name,age,sex
1,lyy,20,F
2,rdd,20,M
3,nyc,18,M
4,mzy,10,M
數據讀取:
1 SparkSession spark = SparkSession.builder().enableHiveSupport() 2 .getOrCreate(); 3 Dataset<Row> dataset = spark 4 .read() 5 .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") 6 .option("header", true) 7 .option("inferSchema", true) 8 .option("delimiter", ",") 9 //.load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/SanFranciscoCrime/document/kaggle-舊金山犯罪分類/train-new.csv") //PreProcess1 10 .load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/DataPreprocessing/document/train.csv") //PreProcess2 11 .persist();
1 public static void PreProcess2(Dataset<Row> data) { 2 3 data.printSchema(); 4 // 重新索引標簽值 5 StringIndexerModel labelIndexer = new StringIndexer() 6 .setInputCol("sex") 7 .setOutputCol("label") 8 .fit(data); 9 10 StringIndexerModel nameIndexer = new StringIndexer() 11 .setInputCol("name") 12 .setOutputCol("namenum") 13 .fit(data); 14 15 16 /* 會報錯:Exception in thread "main" java.lang.IllegalArgumentException: Field "namenum" does not exist. 17 * 原因是:Model類型調用fit時,要求數據集中必須包含InputCol所指定的列名 18 * 不會將Pipeline某個stage的輸出作為InputCol,即使那個stage的OutputCol指定的列名與其相同也不行 19 * StringIndexerModel name1Indexer = new StringIndexer() 20 .setInputCol("namenum") 21 .setOutputCol("namenum1") 22 .fit(data);*/ 23 24 25 /* 錯誤原因StringIndexerModel錯誤一樣,features並不是data的列 26 * VectorIndexerModel featureIndexer = new VectorIndexer() 27 .setInputCol("features") 28 .setOutputCol("indexfeatures") 29 .setMaxCategories(4) 30 .fit(data);*/ 31 32 //成功 33 //原因說明:非model時,轉換器不會調用fit,而會使用Pipeline某個stage的輸出作為InputCol 34 //由於stage[2]即 assembler已經生成features,故而該處直接使用; 35 //但是該類型時不能單獨使用,必須依賴Pipeline 36 VectorIndexer featureIndexer = new VectorIndexer() 37 .setInputCol("features") 38 .setOutputCol("indexfeatures") 39 .setMaxCategories(4); 40 41 //由上述分析可知,該處輸入的列可以是多個stage的輸出組成,因為VectorAssembler非model 42 //因此可以使用中間生成結果,且可以使用多個 43 VectorAssembler assembler = new VectorAssembler() 44 .setInputCols("id,namenum,age".split(",")) 45 .setOutputCol("features"); 46 47 //這里的stage的順序很重要,一定按照依賴關系順序放入,如下順序就會報錯: 48 //Exception in thread "main" java.lang.IllegalArgumentException: Field "features" does not exist. 49 //Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,featureIndexer,assembler}); 50 51 //將featureIndexer放到assembler即可 52 Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,assembler,featureIndexer}); 53 54 // Train model. This also runs the indexers. 55 PipelineModel model = pipeline.fit(data); 56 57 // Make predictions. 
58 Dataset<Row> result = model.transform(data); 59 60 result.show(10, false); 61 62 }
root
|-- id: integer (nullable = true)
|-- name: string (nullable = true)
|-- age: integer (nullable = true)
|-- sex: string (nullable = true)
+---+----+---+---+-----+-------+--------------+-------------+
|id |name|age|sex|label|namenum|features |indexfeatures|
+---+----+---+---+-----+-------+--------------+-------------+
|1 |lyy |20 |F |1.0 |1.0 |[1.0,1.0,20.0]|[0.0,1.0,2.0]|
|2 |rdd |20 |M |0.0 |2.0 |[2.0,2.0,20.0]|[1.0,2.0,2.0]|
|3 |nyc |18 |M |0.0 |0.0 |[3.0,0.0,18.0]|[2.0,0.0,1.0]|
|4 |mzy |10 |M |0.0 |3.0 |[4.0,3.0,10.0]|[3.0,3.0,0.0]|
+---+----+---+---+-----+-------+--------------+-------------+
綜上分析,可以將原有代碼做一簡化:所有 stage 都不再預先調用 fit,而是直接按依賴順序放入 Pipeline,由 pipeline.fit(data) 統一擬合:
1 public static void PreProcess2(Dataset<Row> data) { 2 3 data.printSchema(); 4 // 重新索引標簽值 5 StringIndexer labelIndexer = new StringIndexer() 6 .setInputCol("sex") 7 .setOutputCol("label"); 8 9 StringIndexer nameIndexer = new StringIndexer() 10 .setInputCol("name") 11 .setOutputCol("namenum"); 12 13 VectorIndexer featureIndexer = new VectorIndexer() 14 .setInputCol("features") 15 .setOutputCol("indexfeatures") 16 .setMaxCategories(4); 17 18 19 VectorAssembler assembler = new VectorAssembler() 20 .setInputCols("id,namenum,age".split(",")) 21 .setOutputCol("features"); 22 23 Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,assembler,featureIndexer}); 24 25 // Train model. This also runs the indexers. 26 PipelineModel model = pipeline.fit(data); //以這里的data為基准數據 27 28 // Make predictions. 29 Dataset<Row> result = model.transform(data); 30 31 result.show(10, false); 32 33 }
運行結果:
root
|-- id: integer (nullable = true)
|-- name: string (nullable = true)
|-- age: integer (nullable = true)
|-- sex: string (nullable = true)
+---+----+---+---+-----+-------+--------------+-------------+
|id |name|age|sex|label|namenum|features |indexfeatures|
+---+----+---+---+-----+-------+--------------+-------------+
|1 |lyy |20 |F |1.0 |1.0 |[1.0,1.0,20.0]|[0.0,1.0,2.0]|
|2 |rdd |20 |M |0.0 |2.0 |[2.0,2.0,20.0]|[1.0,2.0,2.0]|
|3 |nyc |18 |M |0.0 |0.0 |[3.0,0.0,18.0]|[2.0,0.0,1.0]|
|4 |mzy |10 |M |0.0 |3.0 |[4.0,3.0,10.0]|[3.0,3.0,0.0]|
+---+----+---+---+-----+-------+--------------+-------------+