K鄰近(k-Nearest Neighbor,KNN)分類算法是最簡單的機器學習算法了。它采用測量不同特征值之間的距離方法進行分類。它的思想很簡單:計算一個點A與其他所有點之間的距離,取出與該點最近的k個點,然后統計這k個點里面所屬分類比例最大的,則點A屬於該分類。
下面用一個例子來說明一下:
電影名稱 |
打斗次數 |
接吻次數 |
電影類型 |
California Man |
3 |
104 |
Romance |
He’s Not Really into Dudes |
2 |
100 |
Romance |
Beautiful Woman |
1 |
81 |
Romance |
Kevin Longblade |
101 |
10 |
Action |
Robo Slayer 3000 |
99 |
5 |
Action |
Amped II |
98 |
2 |
Action |
簡單說一下這個數據的意思:這里用打斗次數和接吻次數來界定電影類型,如上,接吻多的是Romance類型的,而打斗多的是動作電影。還有一部名字未知(這里名字未知是為了防止能從名字中猜出電影類型),打斗次數為18次,接吻次數為90次的電影,它到底屬於哪種類型的電影呢?
KNN算法要做的,就是先用打斗次數和接吻次數作為電影的坐標,然后計算其他六部電影與未知電影之間的距離,取得前K個距離最近的電影,然后統計這k個距離最近的電影里,屬於哪種類型的電影最多,比如Action最多,則說明未知的這部電影屬於動作片類型。
在實際使用中,有幾個問題是值得注意的:K值的選取,選多大合適呢?計算兩者間距離,用哪種距離會更好呢?計算量太大怎么辦?假設樣本中,類型分布非常不均,比如Action的電影有200部,但是Romance的電影只有20部,這樣計算起來,即使不是Action的電影,也會因為Action的樣本太多,導致k個最近鄰居里有不少Action的電影,這樣該怎么辦呢?
沒有萬能的算法,只有在一定使用環境中最優的算法。
1.1 算法指導思想
kNN算法的指導思想是“近朱者赤,近墨者黑”,由你的鄰居來推斷出你的類別。
先計算待分類樣本與已知類別的訓練樣本之間的距離,找到距離與待分類樣本數據最近的k個鄰居;再根據這些鄰居所屬的類別來判斷待分類樣本數據的類別。
1.2相似性度量
用空間內兩個點的距離來度量。距離越大,表示兩個點越不相似。距離的選擇有很多[13],通常用比較簡單的歐式距離。
歐式距離:
馬氏距離:馬氏距離能夠緩解由於屬性的線性組合帶來的距離失真,是數據的協方差矩陣。
曼哈頓距離:
切比雪夫距離:
閔氏距離:r取值為2時:曼哈頓距離;r取值為1時:歐式距離。
平均距離:
弦距離:
測地距離:
1.2 類別的判定
投票決定:少數服從多數,近鄰中哪個類別的點最多就分為該類。
加權投票法:根據距離的遠近,對近鄰的投票進行加權,距離越近則權重越大(權重為距離平方的倒數)
優缺點
1.2.1 優點
- 簡單,易於理解,易於實現,無需估計參數,無需訓練;
- 適合對稀有事件進行分類;
- 特別適合於多分類問題(multi-modal,對象具有多個類別標簽), kNN比SVM的表現要好。
- 懶惰算法,對測試樣本分類時的計算量大,內存開銷大,評分慢;
- 當樣本不平衡時,如一個類的樣本容量很大,而其他類樣本容量很小時,有可能導致當輸入一個新樣本時,該樣本的K個鄰居中大容量類的樣本占多數;
- 可解釋性較差,無法給出決策樹那樣的規則。
1.2.2 缺點
1.3 常見問題
1.3.1 k值的設定
k值選擇過小,得到的近鄰數過少,會降低分類精度,同時也會放大噪聲數據的干擾;而如果k值選擇過大,並且待分類樣本屬於訓練集中包含數據數較少的類,那么在選擇k個近鄰的時候,實際上並不相似的數據亦被包含進來,造成噪聲增加而導致分類效果的降低。
如何選取恰當的K值也成為KNN的研究熱點。k值通常是采用交叉檢驗來確定(以k=1為基准)。
經驗規則:k一般低於訓練樣本數的平方根。
1.3.2 類別的判定方式
投票法沒有考慮近鄰的距離的遠近,距離更近的近鄰也許更應該決定最終的分類,所以加權投票法更恰當一些。
1.3.3 距離度量方式的選擇
高維度對距離衡量的影響:眾所周知當變量數越多,歐式距離的區分能力就越差。
變量值域對距離的影響:值域越大的變量常常會在距離計算中占據主導作用,因此應先對變量進行標准化。
1.3.4 訓練樣本的參考原則
學者們對於訓練樣本的選擇進行研究,以達到減少計算的目的,這些算法大致可分為兩類。第一類,減少訓練集的大小。KNN算法存儲的樣本數據,這些樣本數據包含了大量冗余數據,這些冗余的數據增了存儲的開銷和計算代價。縮小訓練樣本的方法有:在原有的樣本中刪掉一部分與分類相關不大的樣本樣本,將剩下的樣本作為新的訓練樣本;或在原來的訓練樣本集中選取一些代表樣本作為新的訓練樣本;或通過聚類,將聚類所產生的中心點作為新的訓練樣本。
在訓練集中,有些樣本可能是更值得依賴的。可以給不同的樣本施加不同的權重,加強依賴樣本的權重,降低不可信賴樣本的影響。
1.3.5 性能問題
kNN是一種懶惰算法,而懶惰的后果:構造模型很簡單,但在對測試樣本分類地的系統開銷大,因為要掃描全部訓練樣本並計算距離。
已經有一些方法提高計算的效率,例如壓縮訓練樣本量等。
1.4 算法流程
- 准備數據,對數據進行預處理
- 選用合適的數據結構存儲訓練數據和測試元組
- 設定參數,如k
- 維護一個大小為k的的按距離由大到小的優先級隊列,用於存儲最近鄰訓練元組。隨機從訓練元組中選取k個元組作為初始的最近鄰元組,分別計算測試元組到這k個元組的距離,將訓練元組標號和距離存入優先級隊列
- 遍歷訓練元組集,計算當前訓練元組與測試元組的距離,將所得距離L 與優先級隊列中的最大距離Lmax
- 進行比較。若L>=Lmax,則舍棄該元組,遍歷下一個元組。若L < Lmax,刪除優先級隊列中最大距離的元
- 組,將當前訓練元組存入優先級隊列。
- 遍歷完畢,計算優先級隊列中k 個元組的多數類,並將其作為測試元組的類別。
9.測試元組集測試完畢后計算誤差率,繼續設定不同的k 值重新進行訓練,最后取誤差率最小的k 值。
Java代碼實現
public class KNN { /** * 設置優先級隊列的比較函數,距離越大,優先級越高 */ private Comparator<KNNNode> comparator =new Comparator<KNNNode>(){ public int compare(KNNNode o1, KNNNode o2) { if (o1.getDistance() >= o2.getDistance()) return -1; else return 1; } }; /** * 獲取K個不同的隨機數 * @param k 隨機數的個數 * @param max 隨機數最大的范圍 * @return 生成的隨機數數組 */ public List<Integer> getRandKNum(int k, int max) { List<Integer> rand = new ArrayList<Integer>(k); for (int i = 0; i < k; i++) { int temp = (int) (Math.random() * max); if (!rand.contains(temp)) rand.add(temp); else i--; } return rand; }
/* 計算測試元組與訓練元組之前的距離 * @param d1 測試元組 * @param d2 訓練元組 * @return 距離值 */ public double calDistance(List<Double> d1, List<Double> d2) { double distance = 0.00; for (int i = 0; i < d1.size(); i++) distance += (d1.get(i) - d2.get(i)) *(d1.get(i)-d2.get(i)); return distance; } /** * 執行KNN算法,獲取測試元組的類別 * @param datas 訓練數據集 * @param testData 測試元組 * @param k 設定的K值 * @return 測試元組的類別 */ public String knn(List<List<Double>> datas, List<Double> testData, int k) { PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode> (k,comparator); List<Integer> randNum = getRandKNum(k, datas.size()); for (int i = 0; i < k; i++) { int index = randNum.get(i); List<Double> currData = datas.get(index); String c = currData.get(currData.size() - 1).toString(); KNNNode node = new KNNNode(index, calDistance(testData, currData), c); pq.add(node); } for (int i = 0; i < datas.size(); i++) { List<Double> t = datas.get(i); double distance = calDistance(testData, t); KNNNode top = pq.peek(); if (top.getDistance() > distance) { pq.remove(); pq.add(new KNNNode(i, distance, t.get(t.size() - 1). toString())); } } return getMostClass(pq);
}
/** * 獲取所得到的k個最近鄰元組的多數類 * @param pq 存儲k個最近近鄰元組的優先級隊列 * @return 多數類的名稱 */ private String getMostClass(PriorityQueue<KNNNode> pq) { Map<String, Integer> classCount=new HashMap<String,Integer>(); int pqsize = pq.size(); for (int i = 0; i < pqsize; i++) { KNNNode node = pq.remove(); String c = node.getC(); if (classCount.containsKey(c)) classCount.put(c, classCount.get(c) + 1); else classCount.put(c, 1); } int maxIndex = -1; int maxCount = 0; Object[] classes = classCount.keySet().toArray(); for (int i = 0; i < classes.length; i++) { if (classCount.get(classes[i]) > maxCount) maxIndex = i; maxCount = classCount.get(classes[i]); } return classes[maxIndex].toString(); } }