在使用Pipeline串聯多個stage時model和非model的區別


train.csv數據:

id,name,age,sex
1,lyy,20,F
2,rdd,20,M
3,nyc,18,M
4,mzy,10,M

 

數據讀取:

 1 SparkSession  spark = SparkSession.builder().enableHiveSupport() // session builder with Hive support enabled
 2                     .getOrCreate();
 3         Dataset<Row> dataset = spark
 4                 .read()
 5                 .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") // read the file as CSV
 6                 .option("header", true) // first row of the CSV holds the column names (id,name,age,sex)
 7                 .option("inferSchema", true) // infer column types from the data (id and age come out as integer in the printed schema)
 8                 .option("delimiter", ",") // fields are comma-separated
 9                 //.load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/SanFranciscoCrime/document/kaggle-舊金山犯罪分類/train-new.csv") //PreProcess1
10                 .load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/DataPreprocessing/document/train.csv") // input for PreProcess2
11                 .persist(); // persist the loaded dataset (presumably so later steps reuse it without re-reading the file)

 

 1     public static void PreProcess2(Dataset<Row> data) {
 2                 // Demonstrates mixing pre-fitted models (the two StringIndexerModels) with unfitted stages in one Pipeline.
 3                 data.printSchema();
 4                 // Re-index the label values
 5                 StringIndexerModel labelIndexer = new StringIndexer()
 6                 .setInputCol("sex")
 7                 .setOutputCol("label")
 8                 .fit(data);
 9                 
10                 StringIndexerModel nameIndexer = new StringIndexer()
11                 .setInputCol("name")
12                 .setOutputCol("namenum")
13                 .fit(data);
14 
15                 
16                 /*  This fails with: Exception in thread "main" java.lang.IllegalArgumentException: Field "namenum" does not exist.
17                  * Reason: when fit(data) is called directly to obtain a Model, the dataset passed in must already contain the column named by InputCol.
18                  * The output of another Pipeline stage is not used as InputCol here, even if that stage's OutputCol has exactly that name.
19                  * StringIndexerModel name1Indexer = new StringIndexer()
20                 .setInputCol("namenum")
21                 .setOutputCol("namenum1")
22                 .fit(data);*/
23                 
24                 
25                 /* Fails for the same reason as the StringIndexerModel case above: "features" is not a column of data
26                  * VectorIndexerModel featureIndexer = new VectorIndexer()
27                     .setInputCol("features")
28                     .setOutputCol("indexfeatures")
29                     .setMaxCategories(4)
30                     .fit(data);*/
31                 
32                 // This works.
33                 // Why: the stage is not a pre-fitted model, so fit is not called here; the output of an earlier Pipeline stage is used as InputCol instead.
34                 // stage[2], i.e. assembler, has already produced "features" by then, so it can be used directly here;
35                 // but in this form the stage cannot be used on its own: it has to run inside the Pipeline.
36                 VectorIndexer featureIndexer = new VectorIndexer()
37                 .setInputCol("features")
38                 .setOutputCol("indexfeatures")
39                 .setMaxCategories(4);
40                 
41                 // As analysed above, the input columns here may come from the outputs of several stages, because VectorAssembler is not a model,
42                 // so it can use intermediate results, and more than one of them ("namenum" comes from nameIndexer).
43                 VectorAssembler assembler = new VectorAssembler()
44                 .setInputCols("id,namenum,age".split(","))
45                .setOutputCol("features");
46                 
47                 // The order of the stages matters: they must be listed in dependency order. The following order fails with:
48                 //Exception in thread "main" java.lang.IllegalArgumentException: Field "features" does not exist.
49                 //Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,featureIndexer,assembler});
50                 
51                 // Placing featureIndexer after assembler fixes it.
52                 Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,assembler,featureIndexer});
53 
54                 // Train model. This also runs the indexers.
55                 PipelineModel model = pipeline.fit(data);
56 
57                 // Make predictions.
58                 Dataset<Row> result = model.transform(data);
59                 
60                 result.show(10, false);
61                 
62     }

 

運行結果:

root
 |-- id: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- sex: string (nullable = true)

+---+----+---+---+-----+-------+--------------+-------------+
|id |name|age|sex|label|namenum|features      |indexfeatures|
+---+----+---+---+-----+-------+--------------+-------------+
|1  |lyy |20 |F  |1.0  |1.0    |[1.0,1.0,20.0]|[0.0,1.0,2.0]|
|2  |rdd |20 |M  |0.0  |2.0    |[2.0,2.0,20.0]|[1.0,2.0,2.0]|
|3  |nyc |18 |M  |0.0  |0.0    |[3.0,0.0,18.0]|[2.0,0.0,1.0]|
|4  |mzy |10 |M  |0.0  |3.0    |[4.0,3.0,10.0]|[3.0,3.0,0.0]|
+---+----+---+---+-----+-------+--------------+-------------+

綜上分析,可以將原有代碼做一簡化:

 1 public static void PreProcess2(Dataset<Row> data) {
 2                 // Simplified variant: none of the four stages is fitted up front; they are all handed to the Pipeline unfitted.
 3                 data.printSchema();
 4                 // Re-index the label values
 5                 StringIndexer labelIndexer = new StringIndexer()
 6                 .setInputCol("sex")
 7                 .setOutputCol("label");
 8                 
 9                 StringIndexer nameIndexer = new StringIndexer()
10                 .setInputCol("name")
11                 .setOutputCol("namenum");
12 
13                 VectorIndexer featureIndexer = new VectorIndexer()
14                 .setInputCol("features")
15                 .setOutputCol("indexfeatures")
16                 .setMaxCategories(4);
17                 
18                 // "namenum" is produced by nameIndexer and "features" is consumed by featureIndexer, hence the stage order below.
19                 VectorAssembler assembler = new VectorAssembler()
20                 .setInputCols("id,namenum,age".split(","))
21                .setOutputCol("features");
22 
23                 Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,assembler,featureIndexer});
24 
25                 // Train model. This also runs the indexers.
26                 PipelineModel model = pipeline.fit(data);  // the data passed here is the reference dataset the stages are fitted on
27 
28                 // Make predictions.
29                 Dataset<Row> result = model.transform(data);
30                 
31                 result.show(10, false);
32                 
33     }

運行結果:

root
 |-- id: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- sex: string (nullable = true)

+---+----+---+---+-----+-------+--------------+-------------+
|id |name|age|sex|label|namenum|features      |indexfeatures|
+---+----+---+---+-----+-------+--------------+-------------+
|1  |lyy |20 |F  |1.0  |1.0    |[1.0,1.0,20.0]|[0.0,1.0,2.0]|
|2  |rdd |20 |M  |0.0  |2.0    |[2.0,2.0,20.0]|[1.0,2.0,2.0]|
|3  |nyc |18 |M  |0.0  |0.0    |[3.0,0.0,18.0]|[2.0,0.0,1.0]|
|4  |mzy |10 |M  |0.0  |3.0    |[4.0,3.0,10.0]|[3.0,3.0,0.0]|
+---+----+---+---+-----+-------+--------------+-------------+

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM