In machine learning, the 1-NN (single nearest neighbor) algorithm is called IB1 in Weka, short for Instance-Based 1: a lazy learner that classifies using only the single nearest training instance.
Below is a summary of my study of the IB1 source code in Weka.
First, add weka-src.jar to the build path, otherwise you cannot step into the source.
1) Load the data, build the IB1 classifier, and evaluate its predictions; this is the entry point for tracing the source.
try {
    File file = new File("F:\\tools/lib/data/contact-lenses.arff");
    ArffLoader loader = new ArffLoader();
    loader.setFile(file);
    Instances ins = loader.getDataSet();
    // The classIndex must be set before the instances are used,
    // otherwise using the Instances object throws an exception.
    ins.setClassIndex(ins.numAttributes() - 1);
    IB1 cfs = new IB1();
    cfs.buildClassifier(ins);
    Instance testInst;
    Evaluation testingEvaluation = new Evaluation(ins);
    int length = ins.numInstances();
    for (int i = 0; i < length; i++) {
        testInst = ins.instance(i);
        // Test the classifier on each instance.
        double predictValue = cfs.classifyInstance(testInst);
        System.out.println(testInst.classValue() + "--" + predictValue);
    }
    // System.out.println("Classifier accuracy: " + (1 - testingEvaluation.errorRate()));
} catch (Exception e) {
    e.printStackTrace();
}
2) Ctrl-click buildClassifier to step into its source. IB1 overrides this abstract method as follows:
public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_Train = new Instances(instances, 0, instances.numInstances());

    m_MinArray = new double[m_Train.numAttributes()];
    m_MaxArray = new double[m_Train.numAttributes()];
    for (int i = 0; i < m_Train.numAttributes(); i++) {
        m_MinArray[i] = m_MaxArray[i] = Double.NaN;
    }
    Enumeration enu = m_Train.enumerateInstances();
    while (enu.hasMoreElements()) {
        updateMinMax((Instance) enu.nextElement());
    }
}
(1) getCapabilities().testWithFail checks whether IB1 can handle the data; IB1 cannot handle string attributes or a numeric class.
(2) deleteWithMissingClass removes instances whose class label is missing.
(3) m_MinArray and m_MaxArray hold each attribute's minimum and maximum; both are double arrays of length numAttributes() (not the number of instances), initialized to NaN.
(4) Iterate over all training instances and record each attribute's minimum and maximum. Next, step into updateMinMax.
3) The source of IB1's updateMinMax method:
private void updateMinMax(Instance instance) {

    for (int j = 0; j < m_Train.numAttributes(); j++) {
        if ((m_Train.attribute(j).isNumeric()) && (!instance.isMissing(j))) {
            if (Double.isNaN(m_MinArray[j])) {
                m_MinArray[j] = instance.value(j);
                m_MaxArray[j] = instance.value(j);
            } else {
                if (instance.value(j) < m_MinArray[j]) {
                    m_MinArray[j] = instance.value(j);
                } else {
                    if (instance.value(j) > m_MaxArray[j]) {
                        m_MaxArray[j] = instance.value(j);
                    }
                }
            }
        }
    }
}
(1) Attributes that are not numeric, and attribute values that are missing, are skipped (note: individual attribute values are filtered here, not whole instances);
(2) Double.isNaN ("not a number") signals that no value has been recorded yet for that attribute: the first value seen becomes both the minimum and the maximum, after which the running minimum and maximum are updated.
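The NaN-sentinel pattern above can be sketched in isolation (a hypothetical standalone helper, not part of the Weka API):

```java
// Tracks a running minimum and maximum, using NaN to mark
// "no value seen yet", exactly like m_MinArray/m_MaxArray in IB1.
public class MinMaxTracker {
    public double min = Double.NaN;
    public double max = Double.NaN;

    public void update(double value) {
        if (Double.isNaN(min)) {
            // First value seen: it is both the minimum and the maximum.
            min = value;
            max = value;
        } else if (value < min) {
            min = value;
        } else if (value > max) {
            max = value;
        }
    }
}
```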
At this point the IB1 model has been "trained". (Some may ask: isn't a lazy algorithm supposed to require no training? In my view, buildClassifier only initializes m_Train and records the minimum and maximum of every attribute across all instances, preparing for the distance computation in the next step.)
Now for the prediction source code:
4) Step into the classifyInstance method; the source is:
public double classifyInstance(Instance instance) throws Exception {

    if (m_Train.numInstances() == 0) {
        throw new Exception("No training instances!");
    }

    double distance, minDistance = Double.MAX_VALUE, classValue = 0;
    updateMinMax(instance);
    Enumeration enu = m_Train.enumerateInstances();
    while (enu.hasMoreElements()) {
        Instance trainInstance = (Instance) enu.nextElement();
        if (!trainInstance.classIsMissing()) {
            distance = distance(instance, trainInstance);
            if (distance < minDistance) {
                minDistance = distance;
                classValue = trainInstance.classValue();
            }
        }
    }

    return classValue;
}
(1) updateMinMax is called first so the minimum and maximum also account for the test instance;
(2) the distance from the test instance to every training instance is computed via the distance method; the smallest distance seen is kept in minDistance, and the class value of that nearest training instance is returned as the prediction.
5) Step into the distance method; the source is:
private double distance(Instance first, Instance second) {

    double diff, distance = 0;

    for (int i = 0; i < m_Train.numAttributes(); i++) {
        if (i == m_Train.classIndex()) {
            continue;
        }
        if (m_Train.attribute(i).isNominal()) {
            // If attribute is nominal
            if (first.isMissing(i) || second.isMissing(i) ||
                ((int) first.value(i) != (int) second.value(i))) {
                distance += 1;
            }
        } else {
            // If attribute is numeric
            if (first.isMissing(i) || second.isMissing(i)) {
                if (first.isMissing(i) && second.isMissing(i)) {
                    diff = 1;
                } else {
                    if (second.isMissing(i)) {
                        diff = norm(first.value(i), i);
                    } else {
                        diff = norm(second.value(i), i);
                    }
                    if (diff < 0.5) {
                        diff = 1.0 - diff;
                    }
                }
            } else {
                diff = norm(first.value(i), i) - norm(second.value(i), i);
            }
            distance += diff * diff;
        }
    }

    return distance;
}
The loop visits every attribute except the class attribute. For a nominal attribute, 1 is added when the two values differ or either is missing. For a numeric attribute, both values are mapped by norm into [0, 1] and the squared difference is accumulated. Note that the method returns the sum of squares (the squared Euclidean distance) without taking the square root; this is safe because omitting the square root does not change which neighbor is nearest.
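The numeric part of the distance computation can be shown standalone. This is a sketch under simplifying assumptions: purely numeric attributes, no missing values, no class attribute in the vectors; the class and method names are mine, not Weka's.

```java
// Squared distance between two purely numeric instances, mirroring the
// numeric branch of IB1's distance(): min-max normalize each attribute,
// then sum the squared differences.
public class Ib1Distance {

    // Same contract as IB1's norm(): returns 0 when no value was seen
    // (min is NaN) or when the attribute is constant (max == min).
    public static double norm(double x, double min, double max) {
        if (Double.isNaN(min) || max == min) {
            return 0;
        }
        return (x - min) / (max - min);
    }

    public static double squaredDistance(double[] a, double[] b,
                                         double[] min, double[] max) {
        double distance = 0;
        for (int i = 0; i < a.length; i++) {
            double diff = norm(a[i], min[i], max[i])
                        - norm(b[i], min[i], max[i]);
            distance += diff * diff;
        }
        return distance;
    }
}
```

With both attributes ranging over [0, 10], the instances (0, 0) and (10, 10) normalize to (0, 0) and (1, 1), so the squared distance is 1 + 1 = 2.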
6) Step into the norm normalization method; the source is:
private double norm(double x, int i) {

    if (Double.isNaN(m_MinArray[i]) || Utils.eq(m_MaxArray[i], m_MinArray[i])) {
        return 0;
    } else {
        return (x - m_MinArray[i]) / (m_MaxArray[i] - m_MinArray[i]);
    }
}
Normalization formula: (x - m_MinArray[i]) / (m_MaxArray[i] - m_MinArray[i]). When the minimum is still NaN (no numeric value has been seen) or the maximum equals the minimum, 0 is returned to avoid division by zero. For example, with min = 2 and max = 10, x = 6 normalizes to (6 - 2) / (10 - 2) = 0.5.
For the full pseudocode of the nearest-neighbor classifier, see the original papers on instance-based learning; I won't reproduce it here.
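Putting the steps above together, here is a minimal end-to-end 1-NN sketch in plain Java: store the training data, track per-attribute min/max, and predict the class of the single closest instance. It assumes purely numeric attributes with no missing values, and it is my own illustration, not the Weka implementation.

```java
import java.util.Arrays;

// A minimal 1-NN classifier following the same steps as IB1:
// "training" only stores the data and records each attribute's
// minimum and maximum; prediction scans for the nearest instance.
public class TinyOneNN {
    private final double[][] train;   // training feature vectors
    private final double[] labels;    // class value per training instance
    private final double[] min, max;  // per-attribute minimum and maximum

    public TinyOneNN(double[][] train, double[] labels) {
        this.train = train;
        this.labels = labels;
        int d = train[0].length;
        min = new double[d];
        max = new double[d];
        Arrays.fill(min, Double.NaN);
        Arrays.fill(max, Double.NaN);
        for (double[] x : train) {
            updateMinMax(x);
        }
    }

    private void updateMinMax(double[] x) {
        for (int j = 0; j < x.length; j++) {
            if (Double.isNaN(min[j])) {
                min[j] = max[j] = x[j];
            } else if (x[j] < min[j]) {
                min[j] = x[j];
            } else if (x[j] > max[j]) {
                max[j] = x[j];
            }
        }
    }

    private double norm(double x, int j) {
        if (Double.isNaN(min[j]) || max[j] == min[j]) {
            return 0;
        }
        return (x - min[j]) / (max[j] - min[j]);
    }

    private double squaredDistance(double[] a, double[] b) {
        double dist = 0;
        for (int j = 0; j < a.length; j++) {
            double diff = norm(a[j], j) - norm(b[j], j);
            dist += diff * diff;
        }
        return dist;
    }

    // Like IB1.classifyInstance: the test instance also updates min/max,
    // then the class of the nearest training instance is returned.
    public double classify(double[] x) {
        updateMinMax(x);
        double best = Double.MAX_VALUE, classValue = 0;
        for (int i = 0; i < train.length; i++) {
            double d = squaredDistance(x, train[i]);
            if (d < best) {
                best = d;
                classValue = labels[i];
            }
        }
        return classValue;
    }
}
```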