【weka】分類,cross-validation,數據


一、分類classifier

  如何利用weka里的類對數據集進行分類,要對數據集進行分類,第一步要指定數據集中哪一列做為類別,如果這一步忘記了(事實上經常會忘記)會出現“Class index is negative (not set)!”這個錯誤,設置某一列為類別用Instances類的成員方法setClassIndex,要設置最后一列為類別則可以用Instances類的numAttributes()成員方法得到屬性的個數再減1。

  然后選擇分類器,比較常用的分類器有J48,NaiveBayes,SMO(LibSVM有Java版的,可以在weka中使用,但要設置路徑),訓練分類器使用J48的buildClassifier(注意J48還有別的分類器它們都繼承自Classifier類,使用方法都差不多),分類數據用J48類中的classifyInstance方法,例中使用的數據集為contact-lenses.arff,分類結果為2.0,結果為2.0的原因是:首先用文本編輯器打開數據集,有一行為@attribute contact-lenses {soft, hard, none},而第一個樣本為young, myope, no, reduced, none,最后一列為類別,也就是contact-lences為類別,第一個樣本的類別為none,在屬性說明中none為第二個所以為2.0(從0開始數)。

 

二、評估Evaluation

  Evaluation類,這次只講一下最簡單的用法,首先初始化一個Evaluation對象,Evaluation類沒有無參的構造函數,一般用Instances對象作為構造函數的參數。

       如果沒有分開訓練集和測試集,可以使用Cross Validation方法EvaluationcrossValidateModel方法的四個參數分別為,第一個是分類器,第二個是在某個數據集上評價的數據集,第三個參數是交叉檢驗的次數(10是比較常見的),第四個是一個隨機數對象。

       如果有訓練集和測試集,可以使用Evaluation 類中的evaluateModel方法,方法中的參數為:第一個為一個訓練過的分類器,第二個參數是在某個數據集上評價的數據集。例中我為了簡單用訓練集再次做為測試集,希望大家不會糊塗。

       提醒大家一下,使用crossValidateModel時,分類器不需要先訓練,這其實也應該是常識了。

       Evaluation中提供了多種輸出方法,大家如果用過weka軟件,會發現方法輸出結果與軟件中某個顯示結果的是對應的。例中的三個方法toClassDetailsStringtoSummaryStringtoMatrixString比較常用。

 

三、特征選擇AttributeSelection

  用AttributeSelection進行特征選擇,它需要設置3個方面,第一:對屬性評價的類(自己到Weka軟件里看一下,英文Attribute Evaluator),第二:搜索的方式(自己到Weka軟件里看一下,英文Search Method),第三:就是你要進行特征選擇的數據集了。最后調用Filter的靜態方法userFilter,感覺寫的都是廢話,一看代碼就明白了。唯一值得一說的也就是別把AttributeSelection的包加錯了,代碼旁邊有注釋。

package org.ml;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.Random;

import weka.attributeSelection.CfsSubsetEval;
import weka.attributeSelection.GreedyStepwise;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.AttributeSelectedClassifier;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.supervised.attribute.AttributeSelection;

public class Test {

    public static Instances getFileInstances(String fileName)
            throws FileNotFoundException, IOException {
        Instances m_Instances = new Instances(new BufferedReader(
                new FileReader(fileName)));
        m_Instances.setClassIndex(m_Instances.numAttributes() - 1);
        return m_Instances;
    }

    public static Evaluation crossValidation(Instances m_Instances,
            Classifier classifier, int numFolds) throws Exception {
        Evaluation evaluation = new Evaluation(m_Instances);
        evaluation.crossValidateModel(classifier, m_Instances, numFolds,
                new Random(1));
        return evaluation;
    }
    
    public static Evaluation evaluateTestData(Instances m_Instances, Classifier classifier) throws Exception {
        int split = (int) (m_Instances.numInstances() * 0.6);
        Instances traindata = new Instances(m_Instances, 0, split);
        Instances testdata = new Instances(m_Instances, split, m_Instances.numInstances() - split);
        classifier.buildClassifier(traindata);
        //下面一行是m_Instances,或traindata,或testdata都沒關系,因為Evaluation構造方法要的只是instance的結構,比如屬性
        Evaluation evaluation = new Evaluation(m_Instances);
        evaluation.evaluateModel(classifier, testdata);
        return evaluation;
    }
    
    public static Instances selectAttrUseFilter(Instances m_Instances) throws Exception {
        AttributeSelection filter = new AttributeSelection();
        filter.setEvaluator(new CfsSubsetEval());
        filter.setSearch(new GreedyStepwise());
        filter.setInputFormat(m_Instances);
        return Filter.useFilter(m_Instances, filter);
    }
    
    public static void selectAttrUseMC(Instances m_Instances, Classifier base) throws Exception {
        AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();
        classifier.setClassifier(base);
        classifier.setEvaluator(new CfsSubsetEval());
        classifier.setSearch(new GreedyStepwise());
        Evaluation evaluation = new Evaluation(m_Instances);
        evaluation.crossValidateModel(classifier, m_Instances, 10, new Random(1));
        System.out.println(evaluation.toSummaryString());
    }
    
    public static void printEvalDetail(Evaluation evaluation) throws Exception {
        System.out.println(evaluation.toClassDetailsString());
        System.out.println(evaluation.toSummaryString());
        System.out.println(evaluation.toMatrixString());
    }

    public static void main(String[] args) throws Exception {
        
        Instances data = getFileInstances("C:\\Program Files\\Weka-3-7\\data\\soybean.arff");
        //交叉驗證
        Evaluation crossEvaluation = crossValidation(data, new J48(), 10);
        printEvalDetail(crossEvaluation);
        
        System.out.println("=====================================");
        //一般分類器分類,部分數據用於train,部分用於test
        Evaluation testEvaluation = evaluateTestData(data, new J48());
        printEvalDetail(testEvaluation);
        
        System.out.println("=====================================");
        //特征篩選
        Instances newData = selectAttrUseFilter(data);
        System.out.println("Oral data:" + data.numAttributes());
        System.out.println("selected data:" + newData.numAttributes());
        testEvaluation = evaluateTestData(newData, new J48());
        printEvalDetail(testEvaluation);
        
        System.out.println("=====================================");
        selectAttrUseMC(data, new J48());
        

//        System.out.println("=====================================");
//        J48 classifer = new J48();
//        classifer.buildClassifier(data);
//        for (int i = 0; i < data.numInstances(); i++) {
//        //輸出每個樣例被分到的類別,如果是二分,分別表示為0和1
// System.out.println(data.instance(i) + " === " + classifer.classifyInstance(data.instance(i))); // } } }

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM