KNN鄰近分類算法


K鄰近(k-Nearest Neighbor,KNN)分類算法是最簡單的機器學習算法了。它采用測量不同特征值之間的距離方法進行分類。它的思想很簡單:計算一個點A與其他所有點之間的距離,取出與該點最近的k個點,然后統計這k個點里面所屬分類比例最大的,則點A屬於該分類。

下面用一個例子來說明一下:

 

電影名稱

打斗次數

接吻次數

電影類型

California Man

3

104

Romance

He’s Not Really into Dudes

2

100

Romance

Beautiful Woman

1

81

Romance

Kevin Longblade

101

10

Action

Robo Slayer 3000

99

5

Action

Amped II

98

2

Action

簡單說一下這個數據的意思:這里用打斗次數和接吻次數來界定電影類型,如上,接吻多的是Romance類型的,而打斗多的是動作電影。還有一部名字未知(這里名字未知是為了防止能從名字中猜出電影類型),打斗次數為18次,接吻次數為90次的電影,它到底屬於哪種類型的電影呢?

KNN算法要做的,就是先用打斗次數和接吻次數作為電影的坐標,然后計算其他六部電影與未知電影之間的距離,取得前K個距離最近的電影,然后統計這k個距離最近的電影里,屬於哪種類型的電影最多,比如Action最多,則說明未知的這部電影屬於動作片類型。

在實際使用中,有幾個問題是值得注意的:K值的選取,選多大合適呢?計算兩者間距離,用哪種距離會更好呢?計算量太大怎么辦?假設樣本中,類型分布非常不均,比如Action的電影有200部,但是Romance的電影只有20部,這樣計算起來,即使不是Action的電影,也會因為Action的樣本太多,導致k個最近鄰居里有不少Action的電影,這樣該怎么辦呢?

沒有萬能的算法,只有在一定使用環境中最優的算法。

1.1 算法指導思想

kNN算法的指導思想是“近朱者赤,近墨者黑”,由你的鄰居來推斷出你的類別。

先計算待分類樣本與已知類別的訓練樣本之間的距離,找到距離與待分類樣本數據最近的k個鄰居;再根據這些鄰居所屬的類別來判斷待分類樣本數據的類別。

 

1.2相似性度量

用空間內兩個點的距離來度量。距離越大,表示兩個點越不相似。距離的選擇有很多[13],通常用比較簡單的歐式距離。

歐式距離

 

馬氏距離:馬氏距離能夠緩解由於屬性的線性組合帶來的距離失真,是數據的協方差矩陣。

 

曼哈頓距離

 

切比雪夫距離

 

閔氏距離:r取值為2時:曼哈頓距離;r取值為1時:歐式距離。

 

 

平均距離

 

弦距離

 

測地距離

 

 

1.2 類別的判定

投票決定:少數服從多數,近鄰中哪個類別的點最多就分為該類。

加權投票法:根據距離的遠近,對近鄰的投票進行加權,距離越近則權重越大(權重為距離平方的倒數)

 優缺點

1.2.1              優點

  1. 簡單,易於理解,易於實現,無需估計參數,無需訓練;
  2. 適合對稀有事件進行分類;
  3. 特別適合於多分類問題(multi-modal,對象具有多個類別標簽), kNN比SVM的表現要好。
  4. 懶惰算法,對測試樣本分類時的計算量大,內存開銷大,評分慢;
  5. 當樣本不平衡時,如一個類的樣本容量很大,而其他類樣本容量很小時,有可能導致當輸入一個新樣本時,該樣本的K個鄰居中大容量類的樣本占多數;
  6. 可解釋性較差,無法給出決策樹那樣的規則。

1.2.2              缺點

1.3 常見問題

1.3.1              k值的設定

k值選擇過小,得到的近鄰數過少,會降低分類精度,同時也會放大噪聲數據的干擾;而如果k值選擇過大,並且待分類樣本屬於訓練集中包含數據數較少的類,那么在選擇k個近鄰的時候,實際上並不相似的數據亦被包含進來,造成噪聲增加而導致分類效果的降低。

如何選取恰當的K值也成為KNN的研究熱點。k值通常是采用交叉檢驗來確定(以k=1為基准)。

經驗規則:k一般低於訓練樣本數的平方根。

1.3.2              類別的判定方式

投票法沒有考慮近鄰的距離的遠近,距離更近的近鄰也許更應該決定最終的分類,所以加權投票法更恰當一些。

1.3.3              距離度量方式的選擇

高維度對距離衡量的影響:眾所周知當變量數越多,歐式距離的區分能力就越差。

變量值域對距離的影響:值域越大的變量常常會在距離計算中占據主導作用,因此應先對變量進行標准化。

1.3.4              訓練樣本的參考原則

學者們對於訓練樣本的選擇進行研究,以達到減少計算的目的,這些算法大致可分為兩類。第一類,減少訓練集的大小。KNN算法存儲的樣本數據,這些樣本數據包含了大量冗余數據,這些冗余的數據增了存儲的開銷和計算代價。縮小訓練樣本的方法有:在原有的樣本中刪掉一部分與分類相關不大的樣本樣本,將剩下的樣本作為新的訓練樣本;或在原來的訓練樣本集中選取一些代表樣本作為新的訓練樣本;或通過聚類,將聚類所產生的中心點作為新的訓練樣本。

在訓練集中,有些樣本可能是更值得依賴的。可以給不同的樣本施加不同的權重,加強依賴樣本的權重,降低不可信賴樣本的影響。

1.3.5              性能問題

kNN是一種懶惰算法,而懶惰的后果:構造模型很簡單,但在對測試樣本分類地的系統開銷大,因為要掃描全部訓練樣本並計算距離。

已經有一些方法提高計算的效率,例如壓縮訓練樣本量等。

1.4 算法流程

  1. 准備數據,對數據進行預處理
  2. 選用合適的數據結構存儲訓練數據和測試元組
  3. 設定參數,如k
  4. 維護一個大小為k的的按距離由大到小的優先級隊列,用於存儲最近鄰訓練元組。隨機從訓練元組中選取k個元組作為初始的最近鄰元組,分別計算測試元組到這k個元組的距離,將訓練元組標號和距離存入優先級隊列
  5. 遍歷訓練元組集,計算當前訓練元組與測試元組的距離,將所得距離L 與優先級隊列中的最大距離Lmax
  6. 進行比較。若L>=Lmax,則舍棄該元組,遍歷下一個元組。若L < Lmax,刪除優先級隊列中最大距離的元
  7. 組,將當前訓練元組存入優先級隊列。
  8. 遍歷完畢,計算優先級隊列中k 個元組的多數類,並將其作為測試元組的類別。

      9.測試元組集測試完畢后計算誤差率,繼續設定不同的k 值重新進行訓練,最后取誤差率最小的k 值。

Java代碼實現

public class KNN {
    /**
     * 設置優先級隊列的比較函數,距離越大,優先級越高
     */
    private Comparator<KNNNode> comparator =new Comparator<KNNNode>(){
        public int compare(KNNNode o1, KNNNode o2) {
            if (o1.getDistance() >= o2.getDistance())
                return -1;
            else
                return 1;
        }
    };
    /**
     * 獲取K個不同的隨機數
     * @param k 隨機數的個數
     * @param max 隨機數最大的范圍
     * @return 生成的隨機數數組
     */
    public List<Integer> getRandKNum(int k, int max) {
        List<Integer> rand = new ArrayList<Integer>(k);
        for (int i = 0; i < k; i++) {
            int temp = (int) (Math.random() * max);
            if (!rand.contains(temp))
                rand.add(temp);
            else
                i--;
        }
        return rand;
    }
 /* 計算測試元組與訓練元組之前的距離
     * @param d1 測試元組
     * @param d2 訓練元組
     * @return 距離值
     */
    public double calDistance(List<Double> d1, List<Double> d2) {
        double distance = 0.00;
        for (int i = 0; i < d1.size(); i++)
            distance += (d1.get(i) - d2.get(i)) *(d1.get(i)-d2.get(i));
        return distance;
    }
    
    /**
     * 執行KNN算法,獲取測試元組的類別
     * @param datas 訓練數據集
     * @param testData 測試元組
     * @param k 設定的K值
     * @return 測試元組的類別
     */
    public String knn(List<List<Double>> datas, List<Double> testData,             int k) {
        PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>                                                     (k,comparator);
        List<Integer> randNum = getRandKNum(k, datas.size());
        for (int i = 0; i < k; i++) {
            int index = randNum.get(i);
            List<Double> currData = datas.get(index);
            String c = currData.get(currData.size() - 1).toString();
            KNNNode node = new KNNNode(index, calDistance(testData,                                         currData), c);
            pq.add(node);
        }
        for (int i = 0; i < datas.size(); i++) {
            List<Double> t = datas.get(i);
            double distance = calDistance(testData, t);
            KNNNode top = pq.peek();
            if (top.getDistance() > distance) {
                pq.remove();
                pq.add(new KNNNode(i, distance, t.get(t.size() - 1).                                     toString()));
            }
        }
        return getMostClass(pq);
}

 

/**
     * 獲取所得到的k個最近鄰元組的多數類
     * @param pq 存儲k個最近近鄰元組的優先級隊列
     * @return 多數類的名稱
     */
    private String getMostClass(PriorityQueue<KNNNode> pq) {
        Map<String, Integer> classCount=new HashMap<String,Integer>();
        int pqsize = pq.size();
        for (int i = 0; i < pqsize; i++) {
            KNNNode node = pq.remove();
            String c = node.getC();
            if (classCount.containsKey(c))
                classCount.put(c, classCount.get(c) + 1);
            else
                classCount.put(c, 1);
        }
        int maxIndex = -1;
        int maxCount = 0;
        Object[] classes = classCount.keySet().toArray(); 
        for (int i = 0; i < classes.length; i++) { 
            if (classCount.get(classes[i]) > maxCount)
                maxIndex = i; maxCount = classCount.get(classes[i]);
        } 
        return classes[maxIndex].toString();
    }
}

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM