一.簡述
Spark是當下非常流行的數據分析框架,而其中的機器學習包Mllib也是其諸多亮點之一,相信很多人也像我那樣想要快些上手spark。下面我將列出實現mllib分類的簡明代碼,代碼中將簡述訓練集和樣本集的結構,以及各分類算法的參數含義。分類模型包括朴素貝葉斯,SVM,決策樹以及隨機森林。
二.實現代碼
import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import java.util.LinkedList; import java.util.List; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.classification.NaiveBayes; import org.apache.spark.mllib.classification.NaiveBayesModel; import org.apache.spark.mllib.classification.SVMModel; import org.apache.spark.mllib.classification.SVMWithSGD; import java.util.HashMap; import java.util.Map; import org.apache.spark.mllib.tree.DecisionTree; import org.apache.spark.mllib.tree.model.DecisionTreeModel; import org.apache.spark.mllib.tree.RandomForest; import org.apache.spark.mllib.tree.model.RandomForestModel; public class test { public static void main(String[] arg){ //生成spark對象 SparkConf conf = new SparkConf(); conf.set("spark.testing.memory","2147480000"); // spark的運行配置,意指占用內存2G JavaSparkContext sc = new JavaSparkContext("local[*]", "Spark", conf); //第一個參數為本地模式,[*]盡可能地獲取多的cpu;第二個是spark應用程序名,可以任意取;第三個為配置文件 //訓練集生成 LabeledPoint pos = new LabeledPoint(1.0, Vectors.dense(2.0, 3.0, 3.0));//規定數據結構為LabeledPoint,1.0為類別標號,Vectors.dense(2.0, 3.0, 3.0)為特征向量 LabeledPoint neg = new LabeledPoint(0.0, Vectors.sparse(3, new int[] {2, 1,1}, new double[] {1.0, 1.0,1.0}));//特征值稀疏時,利用sparse構建 List l = new LinkedList();//利用List存放訓練樣本 l.add(neg); l.add(pos); JavaRDD<LabeledPoint>training = sc.parallelize(l); //RDD化,泛化類型為LabeledPoint 而不是List final NaiveBayesModel nb_model = NaiveBayes.train(training.rdd()); //測試集生成 double [] d = {1,1,2}; Vector v = Vectors.dense(d);//測試對象為單個vector,或者是RDD化后的vector //朴素貝葉斯 System.out.println(nb_model.predict(v));// 分類結果 System.out.println(nb_model.predictProbabilities(v)); // 計算概率值 //支持向量機 int numIterations = 100;//迭代次數 final SVMModel svm_model = SVMWithSGD.train(training.rdd(), numIterations);//構建模型 System.out.println(svm_model.predict(v)); //決策樹 Integer numClasses = 2;//類別數量 Map<Integer, Integer> categoricalFeaturesInfo = new HashMap(); String impurity = "gini";//對於分類問題,我們可以用熵entropy或Gini來表示信息的無序程度 ,對於回歸問題,我們用方差(Variance)來表示無序程度,方差越大,說明數據間差異越大 Integer maxDepth = 5;//最大樹深 Integer maxBins = 32;//最大划分數 final DecisionTreeModel tree_model = DecisionTree.trainClassifier(training, numClasses,categoricalFeaturesInfo, impurity, maxDepth, maxBins);//構建模型 System.out.println("決策樹分類結果:"); System.out.println(tree_model.predict(v)); //隨機森林 Integer numTrees = 3; // Use more in practice. String featureSubsetStrategy = "auto"; // Let the algorithm choose. Integer seed = 12345; // Train a RandomForest model. final RandomForestModel forest_model = RandomForest.trainRegressor(training, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed);//參數與決策數基本一致,除了seed System.out.println("隨機森林結果:"); System.out.println(forest_model.predict(v)); } }
三.注意
1.利用spark進行數據分析時,數據一般要轉化為RDD(利用spark所提供接口讀取外部文件,一般會自動轉化為RDD,通過MapReduce處理同樣可以產生與接口匹配的訓練集)
2.訓練樣本統一為標簽向量(LabelPoint)。樣本集為List,但是轉化為RDD時,數據類型卻為JavaRDD<LabeledPoint>(模型訓練時,接口只接收數據類型為JavaRDD<LabeledPoint>)
3.分類predict返回結果為類別標簽,貝葉斯模型可返回屬於不同類的概率(python沒用該接口)