KNN算法的實現(R語言)


一 . K-近鄰算法(KNN)概述 

    最簡單最初級的分類器是將全部的訓練數據所對應的類別都記錄下來,當測試對象的屬性和某個訓練對象的屬性完全匹配時,便可以對其進行分類。但是怎么可能所有測試對象都會找到與之完全匹配的訓練對象呢,其次就是存在一個測試對象同時與多個訓練對象匹配,導致一個訓練對象被分到了多個類的問題,基於這些問題呢,就產生了KNN。

  KNN是通過測量不同特征值之間的距離進行分類。它的思路是:如果一個樣本在特征空間中的k個最相似(即特征空間中最鄰近)的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別,其中K通常是不大於20的整數。KNN算法中,所選擇的鄰居都是已經正確分類的對象。該方法在定類決策上只依據最鄰近的一個或者幾個樣本的類別來決定待分樣本所屬的類別。

  下面通過一個簡單的例子說明一下:如下圖,綠色圓要被決定賦予哪個類,是紅色三角形還是藍色四方形?如果K=3,由於紅色三角形所占比例為2/3,綠色圓將被賦予紅色三角形那個類,如果K=5,由於藍色四方形比例為3/5,因此綠色圓被賦予藍色四方形類。由此也說明了KNN算法的結果很大程度取決於K的選擇。

在KNN中,通過計算對象間距離來作為各個對象之間的非相似性指標,避免了對象之間的匹配問題,在這里距離一般使用歐氏距離或曼哈頓距離。同時,KNN通過依據k個對象中占優的類別進行決策,而不是單一的對象類別決策。這兩點就是KNN算法的優勢。

     接下來對KNN算法的思想總結一下:就是在訓練集中數據和標簽已知的情況下,輸入測試數據,將測試數據的特征與訓練集中對應的特征進行相互比較,找到訓練集中與之最為相似的前K個數據,則該測試數據對應的類別就是K個數據中出現次數最多的那個分類,其算法的描述為:

1)計算測試數據與各個訓練數據之間的距離;

2)按照距離的遞增關系進行排序;

3)選取距離最小的K個點;

4)確定前K個點所在類別的出現頻率;

5)返回前K個點中出現頻率最高的類別作為測試數據的預測分類。

 

代碼實現:

# KNN Regression function 
knn <- function(train.data, train.label, test.data, K=3, distance = 'euclidean'){
    ## count number of train samples
    train.len <- nrow(train.data)
  
    ## count number of test samples
    test.len <- nrow(test.data)
    
    ## New List for hold the test label (the length is the same with the length of test data)
    test.label <- rep(0,test.len)
    ## calculate distances between samples
    dist <- as.matrix(dist(rbind(test.data, train.data), method= distance))[1:test.len, (test.len+1):(test.len+train.len)]
  
    ## for each test sample
    for (i in 1:test.len){
        ### find its K nearest neighbours from training sampels...
        nn <- as.data.frame(sort(dist[i,], index.return = TRUE))[1:K,2]
    
        ### and calculate the predicted labels according to the average of the neighbors’ values
        test.label[i]<-mean(train.label[nn])
    }
  
    ## return the predict values
    return (test.label)
}

## Load Library
library(ggplot2)
library(reshape2)

## Load Dataset
Task1A_train <- read.csv("Task1A_train.csv") 
Task1A_test <- read.csv("Task1A_test.csv")

這里,我們計算error function,直接就是求MSE

距離的話 我們直接用幾何距離

## Create train and test data
train.data <- as.matrix(Task1A_train[,1])
train.label <- as.matrix(Task1A_train[,-1])
test.data <- as.matrix(Task1A_test[,1])
test.label <- as.matrix(Task1A_test[,-1])

# KNN Regression function 
knn <- function(train.data, train.label, test.data, K=3, distance = 'euclidean'){
    ## count number of train samples
    train.len <- nrow(train.data)
  
    ## count number of test samples
    test.len <- nrow(test.data)
    
    ## New List for hold the test label (the length is the same with the length of test data)
    test.label <- rep(0,test.len)
    ## calculate distances between samples
    dist <- as.matrix(dist(rbind(test.data, train.data), method= distance))[1:test.len, (test.len+1):(test.len+train.len)]
  
    ## for each test sample
    for (i in 1:test.len){
        ### find its K nearest neighbours from training sampels...
        nn <- as.data.frame(sort(dist[i,], index.return = TRUE))[1:K,2]
    
        ### and calculate the predicted labels according to the average of the neighbors’ values
        test.label[i]<-mean(train.label[nn])
    }
  
    ## return the predict values
    return (test.label)
}



# calculate the regression error for K in 1:20 
# Here we use Mean Square Error (MSE) 
miss <- data.frame('K'=1:20, 'train'=rep(0,20), 'test'=rep(0,20)) # New data frame to store the error value
for (k in 1:20){
    miss[k,'train'] <- sum((knn(train.data, train.label, train.data, K=k) - train.label) ^ 2)/nrow(train.data)
    miss[k,'test'] <-  sum((knn(train.data, train.label, test.data, K=k)  - test.label) ^ 2)/nrow(test.data)
}

我們這里采用的是K 從1 到20 

那門我們把產生的error給plot出來,大概判斷下,哪個K是最小的

## Plot the training and the testing errors versus 1/K for K=1,..,20
miss.m <- melt(miss, id='K') # reshape for visualization
names(miss.m) <- c('K', 'type', 'error')
ggplot(data=miss.m, aes(x=log(1/K), y=error, color=type)) + geom_line() +
       geom_point() +
       scale_color_discrete(guide = guide_legend(title = NULL)) + theme_minimal() +
       ggtitle("KNN Regression Error")

這里我們可以發現 在K= 11的時候,此時針對這個數據集,KNN的在error處理上效果更好。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM