關於spark的mllib學習總結(Java版)


本篇博客主要講述如何利用spark的mliib構建機器學習模型並預測新的數據,具體的流程如下圖所示:

 

加載數據 對於數據的加載或保存,mllib提供了MLUtils包,其作用是Helper methods to load,save and pre-process data used in MLLib.博客中的數據是采用spark中提供的數據sample_libsvm_data.txt,其有一百個數據樣本,658個特征。具體的數據形式如圖所示: 

加載libsvm 

JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();

LabeledPoint數據類型是對應與libsvmfile格式文件, 具體格式為: Lable(double類型),vector(Vector類型) 轉化dataFrame數據類型 

JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow()); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("features", new VectorUDT(), false, Metadata.empty()), }); SQLContext jsql = new SQLContext(sc); DataFrame df = jsql.createDataFrame(jrow, schema);

DataFrame:DataFrame是一個以命名列方式組織的分布式數據集。在概念上,它跟關系型數據庫中的一張表或者1個Python(或者R)中的data frame一樣,但是比他們更優化。DataFrame可以根據結構化的數據文件、hive表、外部數據庫或者已經存在的RDD構造。 SQLContext:spark sql所有功能的入口是SQLContext類,或者SQLContext的子類。為了創建一個基本的SQLContext,需要一個SparkContext。 特征提取 特征歸一化處理 

StandardScaler scaler = new StandardScaler().setInputCol("features").setOutputCol("normFeatures").setWithStd(true); DataFrame scalerDF = scaler.fit(df).transform(df); scaler.save(this.scalerModelPath);

利用卡方統計做特征提取 

ChiSqSelector selector = new ChiSqSelector().setNumTopFeatures(500).setFeaturesCol("normFeatures").setLabelCol("label").setOutputCol("selectedFeatures"); ChiSqSelectorModel chiModel = selector.fit(scalerDF); DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures"); chiModel.save(this.featureSelectedModelPath);

訓練機器學習模型(以SVM為例)

//轉化為LabeledPoint數據類型, 訓練模型
JavaRDD<Row> selectedrows = selectedDF.javaRDD(); JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel()); //訓練SVM模型, 並保存
int numIteration = 200; SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration); model.clearThreshold(); model.save(sc, this.mlModelPath); // LabeledPoint數據類型轉化為Row
static class LabeledPointToRow implements Function<LabeledPoint, Row> { public Row call(LabeledPoint p) throws Exception { double label = p.label(); Vector vector = p.features(); return RowFactory.create(label, vector); } } //Rows數據類型轉化為LabeledPoint
static class RowToLabel implements Function<Row, LabeledPoint> { public LabeledPoint call(Row r) throws Exception { Vector features = r.getAs(1); double label = r.getDouble(0); return new LabeledPoint(label, features); } }

測試新的樣本 測試新的樣本前,需要將樣本做數據的轉化和特征提取的工作,所有剛剛訓練模型的過程中,除了保存機器學習模型,還需要保存特征提取的中間模型。具體代碼如下:

//初始化spark
SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local"); conf.set("spark.testing.memory", "2147480000"); SparkContext sc = new SparkContext(conf); //加載測試數據
JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD(); //轉化DataFrame數據類型
JavaRDD<Row> jrow =testData.map(new LabeledPointToRow()); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("features", new VectorUDT(), false, Metadata.empty()), }); SQLContext jsql = new SQLContext(sc); DataFrame df = jsql.createDataFrame(jrow, schema); //數據規范化
StandardScaler scaler = StandardScaler.load(this.scalerModelPath); DataFrame scalerDF = scaler.fit(df).transform(df); //特征選取
ChiSqSelectorModel chiModel = ChiSqSelectorModel.load( this.featureSelectedModelPath); DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");

測試數據集

SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath); JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel)) ; predictResult.collect(); static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> { SVMModel model; public Prediction(SVMModel model){ this.model = model; } public Tuple2<Double, Double> call(LabeledPoint p) throws Exception { Double score = model.predict(p.features()); return new Tuple2<Double , Double>(score, p.label()); } }

計算准確率

double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count(); System.out.println(accuracy); static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> { public Boolean call(Tuple2<Double, Double> t) throws Exception { double score = t._1(); double label = t._2(); System.out.print("score:" + score + ", label:"+ label); if(score >= 0.0 && label >= 0.0) return true; else if(score < 0.0 && label < 0.0) return true; else return false; } }

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM