Spark-Mllib中各分類算法的java實現(簡易教程)


一.簡述

  Spark是當下非常流行的數據分析框架,而其中的機器學習包Mllib也是其諸多亮點之一,相信很多人也像我那樣想要快些上手spark。下面我將列出實現mllib分類的簡明代碼,代碼中將簡述訓練集和樣本集的結構,以及各分類算法的參數含義。分類模型包括朴素貝葉斯,SVM,決策樹以及隨機森林。

 

二.實現代碼

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import java.util.LinkedList;
import java.util.List;

import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;

import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;

import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;

public class test {
    public static void main(String[] arg){
       //生成spark對象
        SparkConf conf = new SparkConf();
        conf.set("spark.testing.memory","2147480000");  // spark的運行配置,意指占用內存2G
        JavaSparkContext sc = new JavaSparkContext("local[*]", "Spark", conf);      //第一個參數為本地模式,[*]盡可能地獲取多的cpu;第二個是spark應用程序名,可以任意取;第三個為配置文件
        
        //訓練集生成
        LabeledPoint pos = new LabeledPoint(1.0, Vectors.dense(2.0, 3.0, 3.0));//規定數據結構為LabeledPoint,1.0為類別標號,Vectors.dense(2.0, 3.0, 3.0)為特征向量
        LabeledPoint neg = new LabeledPoint(0.0, Vectors.sparse(3, new int[] {2, 1,1}, new double[] {1.0, 1.0,1.0}));//特征值稀疏時,利用sparse構建
       List l = new LinkedList();//利用List存放訓練樣本
        l.add(neg);
        l.add(pos);
        JavaRDD<LabeledPoint>training = sc.parallelize(l); //RDD化,泛化類型為LabeledPoint 而不是List
        final NaiveBayesModel nb_model = NaiveBayes.train(training.rdd());        
        
        //測試集生成
        double []  d = {1,1,2};
        Vector v =  Vectors.dense(d);//測試對象為單個vector,或者是RDD化后的vector

        //朴素貝葉斯
      System.out.println(nb_model.predict(v));// 分類結果
      System.out.println(nb_model.predictProbabilities(v)); // 計算概率值

      
      //支持向量機
      int numIterations = 100;//迭代次數
      final SVMModel svm_model = SVMWithSGD.train(training.rdd(), numIterations);//構建模型
      System.out.println(svm_model.predict(v));

      //決策樹
      Integer numClasses = 2;//類別數量
      Map<Integer, Integer> categoricalFeaturesInfo = new HashMap();
      String impurity = "gini";//對於分類問題,我們可以用熵entropy或Gini來表示信息的無序程度 ,對於回歸問題,我們用方差(Variance)來表示無序程度,方差越大,說明數據間差異越大
      Integer maxDepth = 5;//最大樹深
      Integer maxBins = 32;//最大划分數
      final DecisionTreeModel tree_model = DecisionTree.trainClassifier(training, numClasses,categoricalFeaturesInfo, impurity, maxDepth, maxBins);//構建模型
      System.out.println("決策樹分類結果:");   
      System.out.println(tree_model.predict(v));
      
      //隨機森林
      Integer numTrees = 3; // Use more in practice.
      String featureSubsetStrategy = "auto"; // Let the algorithm choose.
      Integer seed = 12345;
      // Train a RandomForest model.
      final RandomForestModel forest_model = RandomForest.trainRegressor(training,
        categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed);//參數與決策數基本一致,除了seed
      System.out.println("隨機森林結果:");   
      System.out.println(forest_model.predict(v));
    }
  }

三.注意

1.利用spark進行數據分析時,數據一般要轉化為RDD(利用spark所提供接口讀取外部文件,一般會自動轉化為RDD,通過MapReduce處理同樣可以產生與接口匹配的訓練集)

2.訓練樣本統一為標簽向量(LabelPoint)。樣本集為List,但是轉化為RDD時,數據類型卻為JavaRDD<LabeledPoint>(模型訓練時,接口只接收數據類型為JavaRDD<LabeledPoint>)

3.分類predict返回結果為類別標簽,貝葉斯模型可返回屬於不同類的概率(python沒用該接口)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM