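Whether you research machine learning algorithms in a lab or develop them at a company, sooner or later you will need to modify an algorithm yourself. This post describes how to add a new or improved machine learning algorithm to Weka.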
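I. Adding a classification algorithm
1. The classifier must extend Classifier or one of its subclasses; the simple ZeroR classifier is used as the example below.
2. Override buildClassifier, one of the main methods: it builds the classifier, i.e. trains the model on the training instances.
3. Override classifyInstance, which returns the predicted class value for a single instance, or implement distributionForInstance, which returns the full class probability distribution for that instance.
4. Override getCapabilities, which declares what kinds of data the classifier can handle; the GUI uses it to decide whether the classifier can be selected for the current dataset, otherwise it is shown greyed out.
5. Provide set/get methods for each parameter (option).
6. Provide globalInfo and the tip-text methods such as seedTipText; they return the descriptions shown in the GUI.
7. Finally, register the classifier in the Weka application as described in Part II.
ZeroR.java source code: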
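Step 3 leaves a choice between the two prediction methods because one can be derived from the other. Once a classifier fills in a class distribution, a single prediction is just the most probable class. For a numeric class, it is the one value in the array. The plain-Java sketch below shows that rule on arrays. It uses no Weka types, and the class name PredictFromDistribution and the NaN stand-in for a missing value are hypothetical. As far as I know, Weka's Classifier base class applies a similar fallback in its default classifyInstance, but check the source of your Weka version before relying on it.

```java
// Sketch: deriving a single prediction from a class distribution.
// Plain arrays stand in for Weka's Instance/Attribute types (assumption).
class PredictFromDistribution {

    // hypothetical stand-in for Weka's "missing value"
    static final double MISSING = Double.NaN;

    /**
     * Nominal class: index of the largest probability, or MISSING if all are 0.
     * Numeric class: the single predicted value stored in dist[0].
     */
    static double classify(double[] dist, boolean nominalClass) {
        if (dist == null) {
            throw new IllegalArgumentException("null distribution predicted");
        }
        if (!nominalClass) {
            return dist[0];
        }
        double max = 0;
        int maxIndex = 0;
        for (int i = 0; i < dist.length; i++) {
            if (dist[i] > max) {
                max = dist[i];
                maxIndex = i;
            }
        }
        return max > 0 ? maxIndex : MISSING;
    }

    public static void main(String[] args) {
        System.out.println(classify(new double[] {0.2, 0.5, 0.3}, true)); // 1.0
        System.out.println(classify(new double[] {3.7}, false));          // 3.7
        System.out.println(classify(new double[] {0, 0}, true));          // NaN
    }
}
```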
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ZeroR.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.rules;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import java.io.*;
import java.util.*;
import weka.core.*;

/**
 * Class for building and using a 0-R classifier. Predicts the mean
 * (for a numeric class) or the mode (for a nominal class).
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.11 $
 */
public class ZeroR extends Classifier implements WeightedInstancesHandler {

  /** The class value 0R predicts. */
  private double m_ClassValue;

  /** The number of instances in each class (null if class numeric). */
  private double [] m_Counts;

  /** The class attribute. */
  private Attribute m_Class;

  /**
   * Returns a string describing classifier
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class for building and using a 0-R classifier. Predicts the mean "
      + "(for a numeric class) or the mode (for a nominal class).";
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    double sumOfWeights = 0;

    m_Class = instances.classAttribute();
    m_ClassValue = 0;
    switch (instances.classAttribute().type()) {
    case Attribute.NUMERIC:
      m_Counts = null;
      break;
    case Attribute.NOMINAL:
      m_Counts = new double [instances.numClasses()];
      for (int i = 0; i < m_Counts.length; i++) {
        m_Counts[i] = 1;
      }
      sumOfWeights = instances.numClasses();
      break;
    default:
      throw new Exception("ZeroR can only handle nominal and numeric class"
                          + " attributes.");
    }
    Enumeration enu = instances.enumerateInstances();
    while (enu.hasMoreElements()) {
      Instance instance = (Instance) enu.nextElement();
      if (!instance.classIsMissing()) {
        if (instances.classAttribute().isNominal()) {
          m_Counts[(int)instance.classValue()] += instance.weight();
        } else {
          m_ClassValue += instance.weight() * instance.classValue();
        }
        sumOfWeights += instance.weight();
      }
    }
    if (instances.classAttribute().isNumeric()) {
      if (Utils.gr(sumOfWeights, 0)) {
        m_ClassValue /= sumOfWeights;
      }
    } else {
      m_ClassValue = Utils.maxIndex(m_Counts);
      Utils.normalize(m_Counts, sumOfWeights);
    }
  }

  /**
   * Classifies a given instance.
   *
   * @param instance the instance to be classified
   * @return index of the predicted class
   */
  public double classifyInstance(Instance instance) {

    return m_ClassValue;
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if class is numeric
   */
  public double [] distributionForInstance(Instance instance)
       throws Exception {

    if (m_Counts == null) {
      double[] result = new double[1];
      result[0] = m_ClassValue;
      return result;
    } else {
      return (double []) m_Counts.clone();
    }
  }

  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {

    if (m_Class == null) {
      return "ZeroR: No model built yet.";
    }
    if (m_Counts == null) {
      return "ZeroR predicts class value: " + m_ClassValue;
    } else {
      return "ZeroR predicts class value: " + m_Class.value((int) m_ClassValue);
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(new ZeroR(), argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}
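ZeroR's buildClassifier boils down to a few lines of arithmetic. The sketch below redoes it on plain arrays so it can be run without Weka. The class ZeroRSketch and its method names are hypothetical, and Double.NaN stands in for a missing class value.

For a nominal class, every count starts at 1 before the instance weights are added (a Laplace-style correction). As a result, a class that never occurs in the training data still gets a non-zero probability. For a numeric class, the prediction is the weighted mean of the class values.

```java
import java.util.Arrays;

// Plain-array sketch of the arithmetic in ZeroR.buildClassifier (no Weka types).
// classValues[i] is instance i's class (a class index for a nominal class),
// weights[i] its weight; Double.NaN stands for a missing class value.
class ZeroRSketch {

    // Nominal class: every count starts at 1, weights are added, then normalized.
    static double[] nominalDistribution(double[] classValues, double[] weights, int numClasses) {
        double[] counts = new double[numClasses];
        Arrays.fill(counts, 1.0);
        double sumOfWeights = numClasses;
        for (int i = 0; i < classValues.length; i++) {
            if (Double.isNaN(classValues[i])) continue; // skip missing class values
            counts[(int) classValues[i]] += weights[i];
            sumOfWeights += weights[i];
        }
        for (int k = 0; k < numClasses; k++) counts[k] /= sumOfWeights;
        return counts;
    }

    // The predicted class (the mode) is the first index with the largest entry.
    static int mode(double[] distribution) {
        int best = 0;
        for (int k = 1; k < distribution.length; k++) {
            if (distribution[k] > distribution[best]) best = k;
        }
        return best;
    }

    // Numeric class: weighted mean of the known class values (0 when there are none).
    static double numericPrediction(double[] classValues, double[] weights) {
        double sum = 0, sumOfWeights = 0;
        for (int i = 0; i < classValues.length; i++) {
            if (Double.isNaN(classValues[i])) continue;
            sum += weights[i] * classValues[i];
            sumOfWeights += weights[i];
        }
        return sumOfWeights > 0 ? sum / sumOfWeights : 0.0;
    }

    public static void main(String[] args) {
        double[] dist = nominalDistribution(
            new double[] {0, 1, 1, Double.NaN}, new double[] {1, 1, 1, 1}, 2);
        System.out.println(Arrays.toString(dist) + " -> class " + mode(dist)); // [0.4, 0.6] -> class 1
        System.out.println(numericPrediction(new double[] {1, 2, 3}, new double[] {1, 1, 2})); // 2.25
    }
}
```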
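II. Adding a fuzzy clustering algorithm
1. Write the fuzzy clustering algorithm against Weka's clusterer interfaces; the source, FuzzyCMeans.java, is listed at the end of this post.
2. Copy the source file into the weka.clusterers package.
3. Edit weka.gui.GenericObjectEditor.props: under the comment #Lists the Clusterers I want to choose from, add weka.clusterers.FuzzyCMeans to the list that follows weka.clusterers.Clusterer=\
4. Check weka.gui.GenericPropertiesCreator.props as well. No change is needed here, because the package weka.clusterers is already listed; only when you put the class into a new package must you add that package to this file.
FuzzyCMeans.java source code: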
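The props file in step 3 is read as a Java properties file. The list of clusterers is one value continued over several lines with a trailing backslash, and the entries are separated by commas. The sketch below parses a small, made-up entry of that shape with java.util.Properties to show that weka.clusterers.FuzzyCMeans ends up as one more item of the same list. Only the FuzzyCMeans line comes from this post; the other entries and the class PropsEntrySketch are illustrative.

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

// Sketch: how a GenericObjectEditor.props-style entry is parsed.
// SAMPLE is an illustrative entry, not a copy of Weka's actual file.
class PropsEntrySketch {

    static final String SAMPLE =
        "# Lists the Clusterers I want to choose from\n"
        + "weka.clusterers.Clusterer=\\\n"
        + " weka.clusterers.EM,\\\n"
        + " weka.clusterers.SimpleKMeans,\\\n"
        + " weka.clusterers.FuzzyCMeans\n";

    /** Returns the comma-separated class names listed under the given key. */
    static String[] listedClasses(String propsText, String key) throws IOException {
        Properties props = new Properties();
        // a trailing '\' joins the next line, whose leading whitespace is dropped
        props.load(new StringReader(propsText));
        String value = props.getProperty(key, "");
        return value.isEmpty() ? new String[0] : value.split("\\s*,\\s*");
    }

    public static void main(String[] args) throws IOException {
        for (String name : listedClasses(SAMPLE, "weka.clusterers.Clusterer")) {
            System.out.println(name);
        }
    }
}
```

When you append the new class at the end of the real list, remember to add a comma and a backslash to the line before it. Otherwise the two class names run together into one invalid entry.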
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    FuzzyCMeans.java
 *    Copyright (C) 2007 Wei Xiaofei
 *
 */

package weka.clusterers;

import weka.classifiers.rules.DecisionTableHashKey;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.core.matrix.Matrix;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

import java.util.Enumeration;
import java.util.HashMap;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Cluster data using the fuzzy c-means algorithm
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -N <num>
 *  number of clusters.
 *  (default 2).</pre>
 *
 * <pre> -F <num>
 *  fuzzifier (weighting exponent).
 *  (default 2.0).</pre>
 *
 * <pre> -S <num>
 *  Random number seed.
 *  (default 10)</pre>
 *
 <!-- options-end -->
 *
 * @author Wei Xiaofei
 * @version 1.03
 * @see RandomizableClusterer
 */
public class FuzzyCMeans
  extends RandomizableClusterer
  implements NumberOfClustersRequestable, WeightedInstancesHandler {

  /** for serialization */
  static final long serialVersionUID = -2134543132156464L;

  /** upper bound on iterations so that buildClusterer always terminates (value chosen arbitrarily) */
  private static final int MAX_ITERATIONS = 1000;

  /** replaces missing values in the training instances */
  private ReplaceMissingValues m_ReplaceMissingFilter;

  /** number of clusters to generate */
  private int m_NumClusters = 2;

  /** distance matrix: D(i,j) = ||c(i) - x(j)||, the Euclidean distance between centroid i and instance j */
  private Matrix D;

  /** the fuzzifier (weighting exponent m) */
  private double m_fuzzifier = 2;

  /** holds the cluster centroids */
  private Instances m_ClusterCentroids;

  /** holds the standard deviations of the numeric attributes in each cluster */
  private Instances m_ClusterStdDevs;

  /** for each cluster, holds the frequency counts for the values of each nominal attribute */
  private int [][][] m_ClusterNominalCounts;

  /** the number of instances in each cluster */
  private int [] m_ClusterSizes;

  /** attribute min values */
  private double [] m_Min;

  /** attribute max values */
  private double [] m_Max;

  /** number of iterations completed before convergence */
  private int m_Iterations = 0;

  /** holds the squared errors for all clusters */
  private double [] m_squaredErrors;

  /** the default constructor */
  public FuzzyCMeans() {
    super();
    m_SeedDefault = 10; // default random seed
    setSeed(m_SeedDefault);
  }

  /**
   * Returns a string describing this clusterer.
   * @return a description of the clusterer suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Cluster data using the fuzzy c-means algorithm";
  }

  /**
   * Returns default capabilities of the clusterer: numeric attributes
   * (missing values allowed) and no class attribute.
   *
   * @return the capabilities of this clusterer
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();
    result.enable(Capability.NO_CLASS);

    // attributes
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    return result;
  }

  /**
   * Generates a clusterer. Has to initialize all fields of the clusterer
   * that are not being set via options.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the clusterer has not been generated successfully
   */
  public void buildClusterer(Instances data) throws Exception {

    // can the clusterer handle the data?
    getCapabilities().testWithFail(data);

    m_Iterations = 0;

    m_ReplaceMissingFilter = new ReplaceMissingValues();
    Instances instances = new Instances(data);
    instances.setClassIndex(-1);
    m_ReplaceMissingFilter.setInputFormat(instances);
    instances = Filter.useFilter(instances, m_ReplaceMissingFilter);

    m_Min = new double [instances.numAttributes()];
    m_Max = new double [instances.numAttributes()];
    for (int i = 0; i < instances.numAttributes(); i++) {
      m_Min[i] = m_Max[i] = Double.NaN; // NaN = not set yet
    }

    m_ClusterCentroids = new Instances(instances, m_NumClusters);
    int[] clusterAssignments = new int [instances.numInstances()];

    for (int i = 0; i < instances.numInstances(); i++) {
      updateMinMax(instances.instance(i));
    }

    Random RandomO = new Random(getSeed());
    int instIndex;
    HashMap initC = new HashMap();
    DecisionTableHashKey hk = null;

    /* choose distinct random instances as initial centroids;
       DecisionTableHashKey is used to detect duplicate instances */
    for (int j = instances.numInstances() - 1; j >= 0; j--) {
      instIndex = RandomO.nextInt(j + 1);
      hk = new DecisionTableHashKey(instances.instance(instIndex),
                                    instances.numAttributes(), true);
      if (!initC.containsKey(hk)) {
        m_ClusterCentroids.add(instances.instance(instIndex));
        initC.put(hk, null);
      }
      instances.swap(j, instIndex);

      if (m_ClusterCentroids.numInstances() == m_NumClusters) {
        break;
      }
    }

    // number of clusters = number of distinct initial centroids found
    m_NumClusters = m_ClusterCentroids.numInstances();

    D = new Matrix(solveD(instances).getArray()); // centroid-to-instance distances

    int i, j;
    int n = instances.numInstances();
    Instances [] tempI = new Instances[m_NumClusters];
    m_ClusterNominalCounts = new int [m_NumClusters][instances.numAttributes()][0];

    Matrix U = new Matrix(solveU(instances).getArray()); // initial membership matrix U
    double q = 0; // previous value of the objective function
    while (true) {
      m_Iterations++;
      m_squaredErrors = new double [m_NumClusters]; // errors of the current iteration only
      for (i = 0; i < instances.numInstances(); i++) {
        Instance toCluster = instances.instance(i);
        // hard assignment: index of the nearest centroid
        int newC = clusterProcessedInstance(toCluster, true);
        clusterAssignments[i] = newC;
      }

      // update centroids: c(i) = sum_k U(i,k)^m x(k) / sum_k U(i,k)^m
      m_ClusterCentroids = new Instances(instances, m_NumClusters);
      for (i = 0; i < m_NumClusters; i++) {
        tempI[i] = new Instances(instances, 0);
      }
      for (i = 0; i < instances.numInstances(); i++) {
        tempI[clusterAssignments[i]].add(instances.instance(i));
      }
      for (i = 0; i < m_NumClusters; i++) {
        double[] vals = new double[instances.numAttributes()];
        for (j = 0; j < instances.numAttributes(); j++) {
          double sum1 = 0, sum2 = 0;
          for (int k = 0; k < n; k++) {
            double w = Math.pow(U.get(i, k), getFuzzifier());
            sum1 += w * instances.instance(k).value(j);
            sum2 += w;
          }
          vals[j] = sum1 / sum2;
        }
        m_ClusterCentroids.add(new Instance(1.0, vals));
      }

      D = new Matrix(solveD(instances).getArray());
      U = new Matrix(solveU(instances).getArray()); // new membership matrix U

      double q1 = 0; // new value of the objective function
      for (i = 0; i < m_NumClusters; i++) {
        for (j = 0; j < n; j++) {
          // objective: q1 += U(i,j)^m * d(i,j)^2
          q1 += Math.pow(U.get(i, j), getFuzzifier()) * D.get(i, j) * D.get(i, j);
        }
      }

      // stop when the change of the objective |q1 - q| drops below a threshold
      // (here machine epsilon, 2.2204e-16) or the iteration limit is reached
      if (Math.abs(q1 - q) < 2.2204e-16 || m_Iterations >= MAX_ITERATIONS) {
        break;
      }
      q = q1;
    }

    // per-cluster standard deviations, computed as in SimpleKMeans
    m_ClusterStdDevs = new Instances(instances, m_NumClusters);
    m_ClusterSizes = new int [m_NumClusters];
    for (i = 0; i < m_NumClusters; i++) {
      double [] vals2 = new double[instances.numAttributes()];
      for (j = 0; j < instances.numAttributes(); j++) {
        if (instances.attribute(j).isNumeric()) {
          vals2[j] = Math.sqrt(tempI[i].variance(j));
        } else {
          vals2[j] = Instance.missingValue();
        }
      }
      m_ClusterStdDevs.add(new Instance(1.0, vals2)); // weight 1.0, attribute values vals2
      m_ClusterSizes[i] = tempI[i].numInstances();
    }
  }

  /**
   * Clusters an instance that has already been through the filters:
   * computes its distance to every centroid and returns the index of
   * the nearest one.
   *
   * @param instance the instance to assign a cluster to
   * @param updateErrors if true, update the within-cluster sum of squared errors
   * @return a cluster number
   */
  private int clusterProcessedInstance(Instance instance, boolean updateErrors) {
    double minDist = Double.MAX_VALUE;
    int bestCluster = 0;
    for (int i = 0; i < m_NumClusters; i++) {
      double dist = distance(instance, m_ClusterCentroids.instance(i));
      if (dist < minDist) {
        minDist = dist;
        bestCluster = i;
      }
    }
    if (updateErrors) {
      m_squaredErrors[bestCluster] += minDist * minDist;
    }
    return bestCluster;
  }

  /**
   * Clusters a given instance: replaces missing values, then calls
   * clusterProcessedInstance().
   *
   * @param instance the instance to be assigned to a cluster
   * @return the number of the assigned cluster as an integer
   * @throws Exception if the instance could not be clustered successfully
   */
  public int clusterInstance(Instance instance) throws Exception {
    m_ReplaceMissingFilter.input(instance);
    m_ReplaceMissingFilter.batchFinished();
    Instance inst = m_ReplaceMissingFilter.output();

    return clusterProcessedInstance(inst, false);
  }

  /**
   * Computes the distance matrix D with d(i,j) = ||c(i) - x(j)||.
   * Zero distances are replaced by a tiny value so that solveU never divides by zero.
   */
  private Matrix solveD(Instances instances) {
    int n = instances.numInstances();
    Matrix D = new Matrix(m_NumClusters, n);
    for (int i = 0; i < m_NumClusters; i++) {
      for (int j = 0; j < n; j++) {
        D.set(i, j, distance(instances.instance(j), m_ClusterCentroids.instance(i)));
        if (D.get(i, j) == 0) {
          D.set(i, j, 0.000000000001);
        }
      }
    }
    return D;
  }

  /**
   * Computes the membership matrix U with
   * U(i,j) = 1 / sum_k (d(i,j) / d(k,j))^(2/(m-1)).
   */
  private Matrix solveU(Instances instances) {
    int n = instances.numInstances();
    int i, j;
    Matrix U = new Matrix(m_NumClusters, n);

    for (i = 0; i < m_NumClusters; i++) {
      for (j = 0; j < n; j++) {
        double sum = 0;
        for (int k = 0; k < m_NumClusters; k++) {
          // (d(i,j) / d(k,j))^(2/(m-1))
          sum += Math.pow(D.get(i, j) / D.get(k, j), 2 / (getFuzzifier() - 1));
        }
        U.set(i, j, Math.pow(sum, -1));
      }
    }
    return U;
  }

  /**
   * Calculates the Euclidean distance between two instances.
   *
   * @param first the first instance
   * @param second the second instance
   * @return the distance between the two given instances
   */
  private double distance(Instance first, Instance second) {
    double val1;
    double val2;
    double dist = 0.0;

    for (int i = 0; i < first.numAttributes(); i++) {
      val1 = first.value(i);
      val2 = second.value(i);
      dist += (val1 - val2) * (val1 - val2);
    }
    dist = Math.sqrt(dist);
    return dist;
  }

  /**
   * Updates the minimum and maximum values for all the attributes
   * based on a new instance (same as in SimpleKMeans).
   *
   * @param instance the new instance
   */
  private void updateMinMax(Instance instance) {
    for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) {
      if (!instance.isMissing(j)) {
        if (Double.isNaN(m_Min[j])) {
          m_Min[j] = instance.value(j);
          m_Max[j] = instance.value(j);
        } else if (instance.value(j) < m_Min[j]) {
          m_Min[j] = instance.value(j);
        } else if (instance.value(j) > m_Max[j]) {
          m_Max[j] = instance.value(j);
        }
      }
    }
  }

  /**
   * Returns the number of clusters.
   *
   * @return the number of clusters generated for a training dataset.
   * @throws Exception if number of clusters could not be returned successfully
   */
  public int numberOfClusters() throws Exception {
    return m_NumClusters;
  }

  /**
   * Returns the fuzzifier (weighting exponent).
   *
   * @return the fuzzifier
   * @throws Exception if the fuzzifier could not be returned successfully
   */
  public double fuzzifier() throws Exception {
    return m_fuzzifier;
  }

  /**
   * Returns an enumeration describing the available options
   * (shown in the GUI's options menu).
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector result = new Vector();

    result.addElement(new Option(
        "\tnumber of clusters.\n"
        + "\t(default 2).",
        "N", 1, "-N <num>"));

    result.addElement(new Option(
        "\tfuzzifier (weighting exponent).\n"
        + "\t(default 2.0).",
        "F", 1, "-F <num>"));

    Enumeration en = super.listOptions();
    while (en.hasMoreElements())
      result.addElement(en.nextElement());

    return result.elements();
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numClustersTipText() {
    return "set number of clusters";
  }

  /**
   * set the number of clusters to generate
   *
   * @param n the number of clusters to generate
   * @throws Exception if the number of clusters is not positive
   */
  public void setNumClusters(int n) throws Exception {
    if (n <= 0) {
      throw new Exception("Number of clusters must be > 0");
    }
    m_NumClusters = n;
  }

  /**
   * gets the number of clusters to generate
   *
   * @return the number of clusters to generate
   */
  public int getNumClusters() {
    return m_NumClusters;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String fuzzifierTipText() {
    return "set fuzzifier";
  }

  /**
   * set the fuzzifier
   *
   * @param f fuzzifier
   * @throws Exception if the fuzzifier is not greater than 1
   */
  public void setFuzzifier(double f) throws Exception {
    if (f <= 1) {
      throw new Exception("F must be > 1");
    }
    m_fuzzifier = f;
  }

  /**
   * get the fuzzifier
   *
   * @return m_fuzzifier
   */
  public double getFuzzifier() {
    return m_fuzzifier;
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -N <num>
   *  number of clusters.
   *  (default 2).</pre>
   *
   * <pre> -F <num>
   *  fuzzifier.
   *  (default 2.0).</pre>
   *
   * <pre> -S <num>
   *  Random number seed.
   *  (default 10)</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String optionString = Utils.getOption('N', options);
    if (optionString.length() != 0) {
      setNumClusters(Integer.parseInt(optionString));
    }

    optionString = Utils.getOption('F', options);
    if (optionString.length() != 0) {
      setFuzzifier(Double.parseDouble(optionString));
    }

    super.setOptions(options);
  }

  /**
   * Gets the current settings of FuzzyCMeans
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions() {
    int i;
    Vector result;
    String[] options;

    result = new Vector();

    result.add("-N");
    result.add("" + getNumClusters());

    result.add("-F");
    result.add("" + getFuzzifier());

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * return a string describing this clusterer
   *
   * @return a description of the clusterer as a string
   */
  public String toString() {
    int maxWidth = 0;
    for (int i = 0; i < m_NumClusters; i++) {
      for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) {
        if (m_ClusterCentroids.attribute(j).isNumeric()) {
          double width = Math.log(Math.abs(m_ClusterCentroids.instance(i).value(j)))
            / Math.log(10.0);
          width += 1.0;
          if ((int) width > maxWidth) {
            maxWidth = (int) width;
          }
        }
      }
    }
    StringBuffer temp = new StringBuffer();
    String naString = "N/A";
    for (int i = 0; i < maxWidth + 2; i++) {
      naString += " ";
    }
    temp.append("\nFuzzy C-means\n======\n");
    temp.append("\nNumber of iterations: " + m_Iterations + "\n");
    temp.append("Within cluster sum of squared errors: " + Utils.sum(m_squaredErrors));

    temp.append("\n\nCluster centroids:\n");
    for (int i = 0; i < m_NumClusters; i++) {
      temp.append("\nCluster " + i + "\n\tCentroid: ");
      for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) {
        temp.append(" " + Utils.doubleToString(m_ClusterCentroids.instance(i).value(j),
                                               maxWidth + 5, 4));
      }
      temp.append("\n\tStd Devs: ");
      for (int j = 0; j < m_ClusterStdDevs.numAttributes(); j++) {
        if (m_ClusterStdDevs.attribute(j).isNumeric()) {
          temp.append(" " + Utils.doubleToString(m_ClusterStdDevs.instance(i).value(j),
                                                 maxWidth + 5, 4));
        } else {
          temp.append(" " + naString);
        }
      }
    }
    temp.append("\n\n");
    return temp.toString();
  }

  /**
   * Gets the cluster centroids
   *
   * @return the cluster centroids
   */
  public Instances getClusterCentroids() {
    return m_ClusterCentroids;
  }

  /**
   * Gets the standard deviations of the numeric attributes in each cluster
   *
   * @return the standard deviations of the numeric attributes
   * in each cluster
   */
  public Instances getClusterStandardDevs() {
    return m_ClusterStdDevs;
  }

  /**
   * Returns for each cluster the frequency counts for the values of each
   * nominal attribute
   *
   * @return the counts
   */
  public int [][][] getClusterNominalCounts() {
    return m_ClusterNominalCounts;
  }

  /**
   * Gets the squared error for all clusters
   *
   * @return the squared error
   */
  public double getSquaredError() {
    return Utils.sum(m_squaredErrors);
  }

  /**
   * Gets the number of instances in each cluster
   *
   * @return The number of instances in each cluster
   */
  public int [] getClusterSizes() {
    return m_ClusterSizes;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments: <p>
   * -t training file [-N number of clusters]
   */
  public static void main(String[] argv) {
    runClusterer(new FuzzyCMeans(), argv);
  }
}
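To experiment with the algorithm outside Weka, here is a compact plain-Java sketch of the same fuzzy c-means iteration on a double[][] array. It uses the formulas from FuzzyCMeans above:

- memberships U(i,j) = 1 / sum_k (d(i,j)/d(k,j))^(2/(m-1));
- centroids weighted by U(i,j)^m;
- the objective, the sum over i and j of U(i,j)^m d(i,j)^2.

The class FuzzyCMeansSketch, the relative tolerance tol, the iteration cap maxIter and the shuffle-based choice of initial centroids are assumptions of this sketch, not part of the original code. It also does not handle missing values.

```java
import java.util.Random;

// Minimal fuzzy c-means sketch on plain arrays (no Weka types).
// Assumptions: x has n >= c rows, m > 1; tol and maxIter are stopping parameters.
class FuzzyCMeansSketch {

    final double[][] centroids; // c x d cluster centres
    final double[][] u;         // c x n memberships; each column sums to 1
    int iterations = 0;

    FuzzyCMeansSketch(double[][] x, int c, double m, double tol, int maxIter, long seed) {
        int n = x.length;
        centroids = new double[c][];
        u = new double[c][n];

        // initial centroids: c distinct row indices via a partial Fisher-Yates shuffle
        Random rnd = new Random(seed);
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) idx[i] = i;
        for (int i = 0; i < c; i++) {
            int j = i + rnd.nextInt(n - i);
            int t = idx[i]; idx[i] = idx[j]; idx[j] = t;
            centroids[i] = x[idx[i]].clone();
        }

        double previous = Double.POSITIVE_INFINITY;
        while (iterations < maxIter) {
            iterations++;
            updateMemberships(x, m);
            updateCentroids(x, m);
            double obj = objective(x, m);
            if (Math.abs(previous - obj) <= tol * obj) break; // relative change of the objective
            previous = obj;
        }
    }

    // Euclidean distance, floored at 1e-12 (as solveD does) so that ratios stay finite
    static double dist(double[] a, double[] b) {
        double s = 0;
        for (int k = 0; k < a.length; k++) s += (a[k] - b[k]) * (a[k] - b[k]);
        return Math.max(Math.sqrt(s), 1e-12);
    }

    // U(i,j) = 1 / sum_k (d(i,j) / d(k,j))^(2/(m-1))
    void updateMemberships(double[][] x, double m) {
        double p = 2.0 / (m - 1.0);
        for (int j = 0; j < x.length; j++) {
            for (int i = 0; i < centroids.length; i++) {
                double dij = dist(x[j], centroids[i]);
                double sum = 0;
                for (int k = 0; k < centroids.length; k++) {
                    sum += Math.pow(dij / dist(x[j], centroids[k]), p);
                }
                u[i][j] = 1.0 / sum;
            }
        }
    }

    // c(i) = sum_j U(i,j)^m x(j) / sum_j U(i,j)^m
    void updateCentroids(double[][] x, double m) {
        int d = x[0].length;
        for (int i = 0; i < centroids.length; i++) {
            double[] num = new double[d];
            double den = 0;
            for (int j = 0; j < x.length; j++) {
                double w = Math.pow(u[i][j], m);
                for (int k = 0; k < d; k++) num[k] += w * x[j][k];
                den += w;
            }
            for (int k = 0; k < d; k++) num[k] /= den;
            centroids[i] = num;
        }
    }

    // J = sum_i sum_j U(i,j)^m * d(i,j)^2
    double objective(double[][] x, double m) {
        double obj = 0;
        for (int i = 0; i < centroids.length; i++) {
            for (int j = 0; j < x.length; j++) {
                double dij = dist(x[j], centroids[i]);
                obj += Math.pow(u[i][j], m) * dij * dij;
            }
        }
        return obj;
    }

    // hard assignment: the cluster with the largest membership for point j
    int hardAssign(int j) {
        int best = 0;
        for (int i = 1; i < centroids.length; i++) {
            if (u[i][j] > u[best][j]) best = i;
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] x = {{0, 0}, {0, 1}, {1, 0}, {10, 10}, {10, 11}, {11, 10}};
        FuzzyCMeansSketch fcm = new FuzzyCMeansSketch(x, 2, 2.0, 1e-9, 100, 10L);
        System.out.println("iterations: " + fcm.iterations);
        for (int j = 0; j < x.length; j++) {
            System.out.println("point " + j + " -> cluster " + fcm.hardAssign(j));
        }
    }
}
```

Because U(i,j) decreases as d(i,j) grows, the cluster with the largest membership is also the nearest centroid. So hardAssign follows the same nearest-centroid rule as clusterProcessedInstance, applied here to the centroids from before the last update.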