一.簡述
Spark是當下非常流行的數據分析框架,而其中的機器學習包Mllib也是其諸多亮點之一,相信很多人也像我那樣想要快些上手spark。下面我將列出實現mllib分類的簡明代碼,代碼中將簡述訓練集和樣本集的結構,以及各分類算法的參數含義。分類模型包括朴素貝葉斯,SVM,決策樹以及隨機森林。
二.實現代碼
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import java.util.LinkedList;
import java.util.List;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
public class test {
public static void main(String[] arg){
//生成spark對象
SparkConf conf = new SparkConf();
conf.set("spark.testing.memory","2147480000"); // spark的運行配置,意指占用內存2G
JavaSparkContext sc = new JavaSparkContext("local[*]", "Spark", conf); //第一個參數為本地模式,[*]盡可能地獲取多的cpu;第二個是spark應用程序名,可以任意取;第三個為配置文件
//訓練集生成
LabeledPoint pos = new LabeledPoint(1.0, Vectors.dense(2.0, 3.0, 3.0));//規定數據結構為LabeledPoint,1.0為類別標號,Vectors.dense(2.0, 3.0, 3.0)為特征向量
LabeledPoint neg = new LabeledPoint(0.0, Vectors.sparse(3, new int[] {2, 1,1}, new double[] {1.0, 1.0,1.0}));//特征值稀疏時,利用sparse構建
List l = new LinkedList();//利用List存放訓練樣本
l.add(neg);
l.add(pos);
JavaRDD<LabeledPoint>training = sc.parallelize(l); //RDD化,泛化類型為LabeledPoint 而不是List
final NaiveBayesModel nb_model = NaiveBayes.train(training.rdd());
//測試集生成
double [] d = {1,1,2};
Vector v = Vectors.dense(d);//測試對象為單個vector,或者是RDD化后的vector
//朴素貝葉斯
System.out.println(nb_model.predict(v));// 分類結果
System.out.println(nb_model.predictProbabilities(v)); // 計算概率值
//支持向量機
int numIterations = 100;//迭代次數
final SVMModel svm_model = SVMWithSGD.train(training.rdd(), numIterations);//構建模型
System.out.println(svm_model.predict(v));
//決策樹
Integer numClasses = 2;//類別數量
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap();
String impurity = "gini";//對於分類問題,我們可以用熵entropy或Gini來表示信息的無序程度 ,對於回歸問題,我們用方差(Variance)來表示無序程度,方差越大,說明數據間差異越大
Integer maxDepth = 5;//最大樹深
Integer maxBins = 32;//最大划分數
final DecisionTreeModel tree_model = DecisionTree.trainClassifier(training, numClasses,categoricalFeaturesInfo, impurity, maxDepth, maxBins);//構建模型
System.out.println("決策樹分類結果:");
System.out.println(tree_model.predict(v));
//隨機森林
Integer numTrees = 3; // Use more in practice.
String featureSubsetStrategy = "auto"; // Let the algorithm choose.
Integer seed = 12345;
// Train a RandomForest model.
final RandomForestModel forest_model = RandomForest.trainRegressor(training,
categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed);//參數與決策數基本一致,除了seed
System.out.println("隨機森林結果:");
System.out.println(forest_model.predict(v));
}
}
三.注意
1.利用spark進行數據分析時,數據一般要轉化為RDD(利用spark所提供接口讀取外部文件,一般會自動轉化為RDD,通過MapReduce處理同樣可以產生與接口匹配的訓練集)
2.訓練樣本統一為標簽向量(LabelPoint)。樣本集為List,但是轉化為RDD時,數據類型卻為JavaRDD<LabeledPoint>(模型訓練時,接口只接收數據類型為JavaRDD<LabeledPoint>)
3.分類predict返回結果為類別標簽,貝葉斯模型可返回屬於不同類的概率(python沒用該接口)
