KNN算法
一、KNN算法概述
KNN是Machine Learning領域一個簡單又實用的算法,與之前討論過的算法主要存在兩點不同:
- 它是一種非參方法。即不必像線性回歸、邏輯回歸等算法一樣有固定格式的模型,也不需要去擬合參數。
- 它既可用於分類,又可應用於回歸。
KNN的基本思想有點類似“物以類聚,人以群分”,打個通俗的比方就是“如果你要了解一個人,可以從他最親近的幾個朋友去推測他是什么樣的人”。
在分類領域,對於一個未知點,選取K個距離(可以是歐氏距離,也可以是其他相似度度量指標)最近的點,然后統計這K個點,在這K個點中頻數最多的那一類就作為分類結果。比如下圖,若令K=4,則?處應分成紅色三角形;若令K=6,則?處應分類藍色正方形。
在回歸(簡單起見,這里討論一元回歸)領域,如果只知道某點的預測變量,要回歸響應變量
,只需要在橫坐標軸上(因為不知道縱坐標的值,所以沒法計算歐氏距離)選取K個最近的點,然后平均(也可以加權平均)這些點的響應值,作為該點的響應值即可。比如下圖中,已知前5個點的橫縱坐標值,求
時,
為多少?若令K=2,則距6.5最近的2個點是(5.1, 8)和(4, 27),把這兩個點的縱坐標平均值17.5就可以當作回歸結果,認為
KNN具體的算法步驟可參考延伸閱讀文獻1。
二、KNN性能討論
KNN的基本思想與計算過程很簡單,你只需要考慮兩件事:
- K預設值取多少?
- 如何定義距離?
其中如何定義距離這個需要結合具體的業務應用背景,本文不細致討論,距離計算方法可參看延伸閱讀文獻2。這里只討論K取值時對算法性能的影響。
在上圖中,紫色虛線是貝葉斯決策邊界線,也是最理想的分類邊界,黑色實線是KNN的分類邊界。
可以發現:K越小,分類邊界曲線越光滑,偏差越小,方差越大;K越大,分類邊界曲線越平坦,偏差越大,方差越小。
所以即使簡單如KNN,同樣要考慮偏差和方差的權衡問題,表現為K的選取。
KNN的優點就是簡單直觀,無需擬合參數,在樣本本身區分度較高的時候效果會很不錯;但缺點是當樣本量大的時候,找出K個最鄰近點的計算代價會很大,會導致算法很慢,此外KNN的可解釋性較差。
KNN的一些其他問題的思考可參看延伸閱讀文獻3。
三、實戰案例
1、KNN在保險業中挖掘潛在用戶的應用
這里應用ISLR
包里的Caravan
數據集,先大致瀏覽一下:
> library(ISLR) > str(Caravan) 'data.frame': 5822 obs. of 86 variables: $ Purchase: Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 1 1 1 1 ... > table(Caravan$Purchase)/sum(as.numeric(table(Caravan$Purchase))) No Yes 0.94022673 0.05977327
5822行觀測,86個變量,其中只有Purchase
是分類型變量,其他全是數值型變量。Purchase
兩個水平,No
和Yes
分別表示不買或買保險。可見到有約6%的人買了保險。
由於KNN算法要計算距離,這85個數值型變量量綱不同,相同兩個點在不同特征變量上的距離差值可能非常大。因此要歸一化,這是Machine Learning的常識。這里直接用scale()
函數將各連續型變量進行正態標准化,即轉化為服從均值為0,標准差為1的正態分布。
> standardized.X=scale(Caravan[,-86]) > mean(standardized.X[,sample(1:85,1)]) [1] -2.047306e-18 > var(standardized.X[,sample(1:85,1)]) [1] 1 > mean(standardized.X[,sample(1:85,1)]) [1] 1.182732e-17 > var(standardized.X[,sample(1:85,1)]) [1] 1 > mean(standardized.X[,sample(1:85,1)]) [1] -3.331466e-17 > var(standardized.X[,sample(1:85,1)]) [1] 1
可見隨機抽取一個標准化后的變量,基本都是均值約為0,標准差為1。
> #前1000觀測作為測試集,其他當訓練集 > test <- 1:1000 > train.X <- standardized.X[-test,] > test.X <- standardized.X[test,] > train.Y <- Caravan$Purchase[-test] > test.Y <- Caravan$Purchase[test] > knn.pred <- knn(train.X,test.X,train.Y,k=) > mean(test.Y!=knn.pred) [1] 0.117 > mean(test.Y!="No") [1] 0.059
當K=1時,KNN總體的分類結果在測試集上的錯誤率約為12%。由於大部分的人都不買保險(先驗概率只有6%),那么如果模型預測不買保險的准確率應當很高,糾結於預測不買保險實際上卻買保險的樣本沒有意義,同樣的也不必考慮整體的准確率(Accuracy)。作為保險銷售人員,只需要關心在模型預測下會買保險的人中有多少真正會買保險,這是精准營銷的精確度(Precision);因此,在這樣的業務背景中,應該着重分析模型的Precesion,而不是Accuracy。
> table(knn.pred,test.Y)
test.Y
knn.pred No Yes
No 874 50 Yes 67 9 > 9/(67+9) [1] 0.1184211
可見K=1時,KNN模型的Precision約為12%,是隨機猜測概率(6%)的兩倍!
下面嘗試K取不同的值:
> knn.pred <- knn(train.X,test.X,train.Y,k=3) > table(knn.pred,test.Y)[2,2]/rowSums(table(knn.pred,test.Y))[2] Yes 0.2 > knn.pred <- knn(train.X,test.X,train.Y,k=5) > table(knn.pred,test.Y)[2,2]/rowSums(table(knn.pred,test.Y))[2] Yes 0.2666667
可以發現當K=3時,Precision=20%;當K=5時,Precision=26.7%。
作為對比,這個案例再用邏輯回歸做一次!
> glm.fit <- glm(Purchase~.,data=Caravan,family = binomial,subset = -test)
Warning message:
glm.fit:擬合機率算出來是數值零或一
> glm.probs <- predict(glm.fit,Caravan[test,],type = "response") > glm.pred <- ifelse(glm.probs >0.5,"Yes","No") > table(glm.pred,test.Y) test.Y glm.pred No Yes No 934 59 Yes 7 0
這個分類效果就差很多,Precision竟然是0!事實上,分類概率閾值為0.5是針對等可能事件,但買不買保險顯然不是等可能事件,把閾值降低到0.25再看看:
> glm.pred <- ifelse(glm.probs >0.25,"Yes","No") > table(glm.pred,test.Y) test.Y glm.pred No Yes No 919 48 Yes 22 11
這下子Precision就達到1/3了,比隨機猜測的精確度高出5倍不止!
以上試驗都充分表明,通過機器學習算法進行精准營銷的精確度比隨機猜測的效果要強好幾倍!
2、KNN回歸
在R中,KNN分類函數是knn()
,KNN回歸函數是knnreg()
。
> #加載數據集BloodBrain,用到向量logBBB和數據框bbbDescr > library(caret) > data(BloodBrain) > class(logBBB) [1] "numeric" > dim(bbbDescr) [1] 208 134 > #取約80%的觀測作訓練集。 > inTrain <- createDataPartition(logBBB, p = .8)[[1]] > trainX <- bbbDescr[inTrain,] > trainY <- logBBB[inTrain] > testX <- bbbDescr[-inTrain,] > testY <- logBBB[-inTrain] > #構建KNN回歸模型 > fit <- knnreg(trainX, trainY, k = 3) > fit 3-nearest neighbor regression model > #KNN回歸模型預測測試集 > pred <- predict(fit, testX) > #計算回歸模型的MSE > mean((pred-testY)^2) [1] 0.5821147
這個KNN回歸模型的MSE只有0.58,可見回歸效果很不錯,偏差很小!下面用可視化圖形比較一下結果。
> #將訓練集、測試集和預測值結果集中比較 > df <-data.frame(class=c(rep("trainY",length(trainY)),rep("testY",length(testY)),rep("predY",length(pred))),Yval=c(trainY,testY,pred)) > ggplot(data=df,mapping = aes(x=Yval,fill=class))+ + geom_dotplot(alpha=0.8)
這是dotplot,橫坐標才是響應變量的值,縱坐標表頻率。比較相鄰的紅色點和綠色點在橫軸上的差異,即表明測試集中預測值與實際值的差距。
> #比較測試集的預測值和實際值 > df2 <- data.frame(testY,pred) > ggplot(data=df2,mapping = aes(x=testY,y=pred))+ + geom_point(color="steelblue",size=3)+ + geom_abline(slope = 1,size=1.5,linetype=2)
這張散點圖則直接將測試集中的實際值和預測值進行對比,虛線是。點離這條虛線越近,表明預測值和實際值之間的差異就越小。
參考文獻
Gareth James et al. An Introduction to Statistical Learning.
Wikipedia. k-nearest neighbors algorithm.
KNN for Smoothing and Prediction.