Eclipse下mallet使用的方法


Mallet是UMass的大牛開發的一個關於統計自然語言處理的開源庫,是很好的一個工具,可以用來學習topic model、訓練ME(最大熵)模型等。對於開發者來說,其官網的技術文檔非常有用。

Mallet可從其官網下載;若要瀏覽開發者文檔,只需在官網點擊相應的“Developer's Guide”。

下面以開發一個簡單的最大熵分類模型為例,可參考文檔

首先下載mallet工具包,該工具包中包含代碼和jar包,簡單起見,我們導入mallet-2.0.7\dist下的mallet.jar和mallet-deps.jar,導入jar包過程為:項目右擊->Properties->Java Build Path->Libraries,點擊“Add JARs”,在路徑中選取相應的jar包即可。

新建Maxent類,代碼如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.pipe.iterator.CsvIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labeling;
import cc.mallet.util.Randoms;

/**
 * Small demonstration of MALLET's maximum entropy classifier: training, saving and
 * restoring a model, predicting a label for one instance, and evaluating on a test file.
 */
public class Maxent implements Serializable{

    /**
     * Trains a maximum entropy (i.e. polytomous logistic regression) classifier
     * using a default-configured {@link MaxEntTrainer}.
     *
     * @param trainingInstances the labelled instances to train on
     * @return the trained classifier
     */
    public static Classifier trainClassifier(InstanceList trainingInstances) {
        return new MaxEntTrainer().train(trainingInstances);
    }

    /**
     * Writes a trained classifier to disk using Java serialization.
     *
     * @param classifier the classifier to persist
     * @param savePath   path of the file to create or overwrite
     * @throws IOException if the file cannot be opened or written
     */
    public void saveClassifier(Classifier classifier,String savePath) throws IOException{
        // try-with-resources closes (and thereby flushes) both streams even when
        // writeObject fails part-way through.
        try (FileOutputStream fileOut = new FileOutputStream(savePath);
             ObjectOutputStream oos = new ObjectOutputStream(fileOut)) {
            oos.writeObject(classifier);
        }
    }

    /**
     * Restores a classifier previously written by {@link #saveClassifier}.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load files that
     * come from a trusted source.
     *
     * @param savedPath path of the serialized classifier file
     * @return the deserialized classifier
     * @throws FileNotFoundException  if {@code savedPath} cannot be opened
     * @throws IOException            if reading the stream fails
     * @throws ClassNotFoundException if a class in the stream is not on the classpath
     */
    public Classifier loadClassifier(String savedPath) throws FileNotFoundException, IOException, ClassNotFoundException{
        // Both streams are declared as resources so the file handle is released even
        // if the ObjectInputStream header check or readObject throws.
        try (FileInputStream fileIn = new FileInputStream(new File(savedPath));
             ObjectInputStream ois = new ObjectInputStream(fileIn)) {
            return (Classifier) ois.readObject();
        }
    }

    /**
     * Classifies a single instance and returns the entry of its best-scoring label.
     *
     * @param classifier   the trained classifier
     * @param testInstance the instance to classify
     * @return the best label's entry, cast to {@code String}
     */
    public String predict(Classifier classifier,Instance testInstance){
        Labeling labeling = classifier.classify(testInstance).getLabeling();
        Label bestLabel = labeling.getBestLabel();
        return (String) bestLabel.getEntry();
    }

    /**
     * Evaluates a classifier on a test file and prints accuracy, the F1 of class
     * {@code "good"}, and the precision of the label at index 1.
     *
     * <p>Each line of the test file is expected in the form {@code [name] [label] [data ...]}.
     * NOTE(review): the label {@code "good"} and label index 1 are hard-coded, so the
     * classifier presumably must have been trained with such labels; confirm before use.
     *
     * @param classifier   the classifier to evaluate
     * @param testFilePath path of the test data file
     * @throws IOException if the test file cannot be opened or closed
     */
    public void evaluate(Classifier classifier, String testFilePath) throws IOException {
        // Test data must pass through the same pipe the classifier was trained with.
        InstanceList testInstances = new InstanceList(classifier.getInstancePipe());

        // The reader is fully consumed by addThruPipe, so it can be closed right after.
        try (FileReader fileReader = new FileReader(new File(testFilePath))) {
            // Regex groups are (name, label, data); the indices are passed as (data, label, name).
            CsvIterator reader = new CsvIterator(fileReader, "(\\w+)\\s+(\\w+)\\s+(.*)", 3, 2, 1);
            testInstances.addThruPipe(reader);
        }

        Trial trial = new Trial(classifier, testInstances);

        // Evaluation metrics: accuracy, F1 and precision.
        System.out.println("Accuracy: " + trial.getAccuracy());
        System.out.println("F1 for class 'good': " + trial.getF1("good"));
        System.out.println("Precision for class '" +
                           classifier.getLabelAlphabet().lookupLabel(1) + "': " +
                           trial.getPrecision(1));
    }

    /**
     * Performs a single random train/test split (90% training, 10% testing), trains on
     * the training portion and returns a trial over the testing portion.
     *
     * @param trainer   the trainer to use; when {@code null}, a default
     *                  {@link MaxEntTrainer} is used via {@link #trainClassifier}
     * @param instances the full instance list to split
     * @return a {@link Trial} of the trained classifier on the held-out 10%
     */
     public static Trial testTrainSplit(MaxEntTrainer trainer, InstanceList instances) {
         final int TRAINING = 0;
         final int TESTING = 1;

         // Third proportion (validation) is 0.0, so that list is not used.
         InstanceList[] instanceLists = instances.split(new Randoms(), new double[] {0.9, 0.1, 0.0});

         // Honour the caller-supplied trainer; previously it was silently ignored.
         Classifier classifier = (trainer != null)
                 ? trainer.train(instanceLists[TRAINING])
                 : trainClassifier(instanceLists[TRAINING]);
         return new Trial(classifier, instanceLists[TESTING]);
      }

    /**
     * Builds a five-instance toy training set over features f1..f3, trains a
     * classifier on it, and prints the predicted label for one test instance.
     *
     * @param args unused
     * @throws FileNotFoundException declared for the optional file-based evaluation
     * @throws IOException           declared for the optional file-based evaluation
     */
    public static void main(String[] args) throws FileNotFoundException,IOException{
        // Feature dictionary; the order of names matches the order of the value arrays below.
        String[] featureNames = {"f1", "f2", "f3"};
        Alphabet featureAlphabet = new Alphabet();
        for (String featureName : featureNames) {
            featureAlphabet.lookupIndex(featureName);
        }

        // Label dictionary, frozen once all classes are registered.
        LabelAlphabet targetAlphabet = new LabelAlphabet();
        targetAlphabet.lookupIndex("positive");
        targetAlphabet.lookupIndex("negative");
        targetAlphabet.lookupIndex("neutral");
        targetAlphabet.stopGrowth();

        // Training samples: one value array per instance, paired by position with targetValue.
        List<double[]> featureValues = Arrays.asList(
                new double[] {1.0, 0.0, 0.0},
                new double[] {2.0, 0.0, 0.0},
                new double[] {0.0, 1.0, 0.0},
                new double[] {0.0, 0.0, 1.0},
                new double[] {0.0, 0.0, 3.0});
        String[] targetValue = {"positive","positive","neutral","negative","negative"};

        InstanceList trainingInstances = new InstanceList (featureAlphabet,targetAlphabet);
        for (int i = 0; i < featureValues.size(); i++) {
            // Key the values by the feature names (not by the label names, as before).
            FeatureVector featureVector =
                    new FeatureVector(featureAlphabet, featureNames, featureValues.get(i));
            // Instance(data, target, name, source)
            Instance instance = new Instance (featureVector,
                    targetAlphabet.lookupLabel(targetValue[i]), "xxx", null);
            trainingInstances.add(instance);
        }

        Classifier maxentclassifier = trainClassifier(trainingInstances);

        // Test example; its target label is only the expected answer, not used for prediction.
        double[] testfeatureValues = {0.5, 0.5, 6.0};
        FeatureVector testfeatureVector =
                new FeatureVector(featureAlphabet, featureNames, testfeatureValues);
        Instance testinstance = new Instance (testfeatureVector,
                targetAlphabet.lookupLabel("negative"), "xxx", null);

        Maxent maxent = new Maxent();
        System.out.print(maxent.predict(maxentclassifier, testinstance));
        // For file-based evaluation, call maxent.evaluate(maxentclassifier, "resource/testdata.txt").
    }
}

說明:trainingInstances為訓練樣本,testinstance為測試樣本,該程序的執行結果為“negative”。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM