一、分類classifier
如何利用weka里的類對數據集進行分類,要對數據集進行分類,第一步要指定數據集中哪一列做為類別,如果這一步忘記了(事實上經常會忘記)會出現“Class index is negative (not set)!”這個錯誤,設置某一列為類別用Instances類的成員方法setClassIndex,要設置最后一列為類別則可以用Instances類的numAttributes()成員方法得到屬性的個數再減1。
然后選擇分類器,比較常用的分類器有J48,NaiveBayes,SMO(LibSVM有Java版的,可以在weka中使用,但要設置路徑),訓練分類器使用J48的buildClassifier(注意J48還有別的分類器它們都繼承自Classifier類,使用方法都差不多),分類數據用J48類中的classifyInstance方法,例中使用的數據集為contact-lenses.arff,分類結果為2.0,結果為2.0的原因是:首先用文本編輯器打開數據集,有一行為@attribute contact-lenses {soft, hard, none},而第一個樣本為young, myope, no, reduced, none,最后一列為類別,也就是contact-lences為類別,第一個樣本的類別為none,在屬性說明中none為第二個所以為2.0(從0開始數)。
二、評估Evaluation
Evaluation類,這次只講一下最簡單的用法,首先初始化一個Evaluation對象,Evaluation類沒有無參的構造函數,一般用Instances對象作為構造函數的參數。
如果沒有分開訓練集和測試集,可以使用Cross Validation方法,Evaluation中crossValidateModel方法的四個參數分別為,第一個是分類器,第二個是在某個數據集上評價的數據集,第三個參數是交叉檢驗的次數(10是比較常見的),第四個是一個隨機數對象。
如果有訓練集和測試集,可以使用Evaluation 類中的evaluateModel方法,方法中的參數為:第一個為一個訓練過的分類器,第二個參數是在某個數據集上評價的數據集。例中我為了簡單用訓練集再次做為測試集,希望大家不會糊塗。
提醒大家一下,使用crossValidateModel時,分類器不需要先訓練,這其實也應該是常識了。
Evaluation中提供了多種輸出方法,大家如果用過weka軟件,會發現方法輸出結果與軟件中某個顯示結果的是對應的。例中的三個方法toClassDetailsString,toSummaryString,toMatrixString比較常用。
三、特征選擇AttributeSelection
用AttributeSelection進行特征選擇,它需要設置3個方面,第一:對屬性評價的類(自己到Weka軟件里看一下,英文Attribute Evaluator),第二:搜索的方式(自己到Weka軟件里看一下,英文Search Method),第三:就是你要進行特征選擇的數據集了。最后調用Filter的靜態方法userFilter,感覺寫的都是廢話,一看代碼就明白了。唯一值得一說的也就是別把AttributeSelection的包加錯了,代碼旁邊有注釋。
package org.ml; import java.io.BufferedReader; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; import java.util.Random; import weka.attributeSelection.CfsSubsetEval; import weka.attributeSelection.GreedyStepwise; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.meta.AttributeSelectedClassifier; import weka.classifiers.trees.J48; import weka.core.Instances; import weka.filters.Filter; import weka.filters.supervised.attribute.AttributeSelection; public class Test { public static Instances getFileInstances(String fileName) throws FileNotFoundException, IOException { Instances m_Instances = new Instances(new BufferedReader( new FileReader(fileName))); m_Instances.setClassIndex(m_Instances.numAttributes() - 1); return m_Instances; } public static Evaluation crossValidation(Instances m_Instances, Classifier classifier, int numFolds) throws Exception { Evaluation evaluation = new Evaluation(m_Instances); evaluation.crossValidateModel(classifier, m_Instances, numFolds, new Random(1)); return evaluation; } public static Evaluation evaluateTestData(Instances m_Instances, Classifier classifier) throws Exception { int split = (int) (m_Instances.numInstances() * 0.6); Instances traindata = new Instances(m_Instances, 0, split); Instances testdata = new Instances(m_Instances, split, m_Instances.numInstances() - split); classifier.buildClassifier(traindata); //下面一行是m_Instances,或traindata,或testdata都沒關系,因為Evaluation構造方法要的只是instance的結構,比如屬性 Evaluation evaluation = new Evaluation(m_Instances); evaluation.evaluateModel(classifier, testdata); return evaluation; } public static Instances selectAttrUseFilter(Instances m_Instances) throws Exception { AttributeSelection filter = new AttributeSelection(); filter.setEvaluator(new CfsSubsetEval()); filter.setSearch(new GreedyStepwise()); filter.setInputFormat(m_Instances); return Filter.useFilter(m_Instances, filter); } public static void selectAttrUseMC(Instances m_Instances, Classifier base) throws Exception { AttributeSelectedClassifier classifier = new AttributeSelectedClassifier(); classifier.setClassifier(base); classifier.setEvaluator(new CfsSubsetEval()); classifier.setSearch(new GreedyStepwise()); Evaluation evaluation = new Evaluation(m_Instances); evaluation.crossValidateModel(classifier, m_Instances, 10, new Random(1)); System.out.println(evaluation.toSummaryString()); } public static void printEvalDetail(Evaluation evaluation) throws Exception { System.out.println(evaluation.toClassDetailsString()); System.out.println(evaluation.toSummaryString()); System.out.println(evaluation.toMatrixString()); } public static void main(String[] args) throws Exception { Instances data = getFileInstances("C:\\Program Files\\Weka-3-7\\data\\soybean.arff"); //交叉驗證 Evaluation crossEvaluation = crossValidation(data, new J48(), 10); printEvalDetail(crossEvaluation); System.out.println("====================================="); //一般分類器分類,部分數據用於train,部分用於test Evaluation testEvaluation = evaluateTestData(data, new J48()); printEvalDetail(testEvaluation); System.out.println("====================================="); //特征篩選 Instances newData = selectAttrUseFilter(data); System.out.println("Oral data:" + data.numAttributes()); System.out.println("selected data:" + newData.numAttributes()); testEvaluation = evaluateTestData(newData, new J48()); printEvalDetail(testEvaluation); System.out.println("====================================="); selectAttrUseMC(data, new J48()); // System.out.println("====================================="); // J48 classifer = new J48(); // classifer.buildClassifier(data); // for (int i = 0; i < data.numInstances(); i++) {
// //輸出每個樣例被分到的類別,如果是二分,分別表示為0和1 // System.out.println(data.instance(i) + " === " + classifer.classifyInstance(data.instance(i))); // } } }