knn原理與實踐


knn是一種基本分類與回歸方法

應用:knn算法不僅可以用於分類,還可以用於回歸..

1、文本分類:文本分類主要應用於信息檢索,機器翻譯,自動文摘,信息過濾,郵件分類等任務.

2、可以使用knn算法做到比較通用的現有用戶產品推薦,基於用戶的最近鄰(長得最像的用戶)買了什么產品來推薦是種介於電子商務網站和sns網站之間的精確營銷.只需要定期(例如每月)維護更新最近鄰表就可以,基於最近鄰表做搜索推薦可以很實時

優點:

1、簡單,易於理解,易於實現,無需估計參數,無需訓練,選擇合適的k,對異常值不敏感;

2、適合於多分類問題(multi-modal,對象具有多個類別標簽)

3、可拓展性強,添加新實例無需重新構造模型

缺點:

1、當樣本不平衡時,如一個類的樣本容量很大,而其他類樣本容量很小時,有可能導致當輸入一個新樣本時,該樣本的K個鄰居中大容量類的樣本占多數.可以采用權值的方法(和該樣本距離小的鄰居權值大)來改進.

2、計算量較大,因為對每一個待分類的文本都要計算它到全體已知樣本的距離,才能求得它的K個最近鄰點.對於大數據分類,即需要大量的訓練樣本,計算復雜度高

3、可理解性差,無法給出像決策樹那樣的規則.

 

距離度量:

1、高維度對距離衡量的影響:眾所周知當變量數越多,歐式距離的區分能力就越差.

2、變量值域對距離的影響:值域越大的變量常常會在距離計算中占據主導作用,因此應先對變量進行標准化.

 

K值選擇:目前采用交叉驗證方式,選出誤差率最小的對應的K

 

構造:

1、隨着樹的深度增加,循環的選取坐標軸,作為分割超平面的法向量。對於3-d tree來說,根節點選取x軸,根節點的孩子選取y軸,根節點的孫子選取z軸,根節點的曾孫子選取x軸,這樣循環下去。

2、每次均為所有對應實例的中位數的實例作為切分點,切分點作為父節點,左右兩側為划分的作為左右兩子樹。

對於n個實例的k維數據來說,建立kd-tree的時間復雜度為O(k*n*logn)。

搜索:

最近鄰搜索如下(k最近鄰,搜索k次,每次將上一次最近鄰刪除)

1、首先從根節點出發找到包含目標點的葉節點,目標點的最近鄰一定在以目標點為中心,並通過當前葉節點的超球體內部,

2、然后從該葉節點出發,依次回退到父節點,

3、如果父節點的另一子節點的區域與超球體相交,則到該區域繼續查找,不斷的查找與目標點最近鄰的節點,直到不能查找最近鄰的節點為止。

 

############################R語言#########################

library(class)

knn(train,test,cl,k=1,l=0,prob=FALSE,use.all=TRUE)

################################案例###########################################

##############################案例#############################################

library(class)

library(nutshell) ######取數據集spambase做案例#########

library(sampling) ########用抽樣函數strata做抽樣###################

data(spambase)

spambase.strata<-strata(spambase,stratanames=c("is_spam"),size=c(1269,1951)

,method="srswor") ########變量ID_unit#描述了樣本中的行號信息###########

spambase.training<-spambase[rownames(spambase)%in%spambase.strata$ID_unit,]

#####訓練集#############

spambase.validation<-spambase[!(rownames(spambase)%in%spambase.strata$ID_unit),]

######驗證集###############

spambase.knn<-knn(train=spambase.training,test=spambase.validation,

cl=spambase.training$is_spam)

##########cl:訓練數據的響應變量(因子類型)######################

summary(spambase.knn)

table(predicted=spambase.knn,actual=spambase.validation$is_spam)

 

 

####################matlab代碼:包含分類與回歸#######################

functionrelustLabel=KNN(test,train,trainlabels,k,type) %% test 為一條輸入測試數據,train為樣本數據,trainlabels為樣本標簽,選取k個臨近值 

    row = size(train,1);

    for j=1:row

        switch type 

            case 1  % 求test到每個樣本的歐氏距離 

                distanceMat(j)=sum((test-train(j,:)).^2);

            case 2  %求test到每個樣本的夾角余弦               

                distanceMat(j)=(train(j,:)*test')/(norm(train(j,:),2)*norm(test,2)); 

                if distanceMat(j)<0 

                    distanceMat(j)=(distanceMat(j)+1)/2;

                end

        end

    end

    distanceMat=distanceMat'; 

    [B, IX] = sort(distanceMat,'ascend');  %距離從小到大排序

    len = min(k,length(B));  %選k個鄰近值,當然k不能超過訓練樣本個數 

    relustLabel = mode(trainlabels(IX(1:len))); % 取眾數(即出現頻率最高的label)作為返回結果

    %%%%%%%%%%%%%%%%%對於回歸而言: relustLabel = avg(trainlabels(IX(1:len)))

end

 

 

 %主程序:

loaddata; 

dataMat = data(:,1:3);

labels = data(:,4);

len = size(dataMat,1);

k = 4; 

error = 0;

%觀察可視化數據 

label1=find(data(:,4)==1);

label2=find(data(:,4)==2);

label3=find(data(:,4)==3); 

plot3(data(label1,1),data(label1,2),data(label1,3),'ro'); 

hold on 

plot3(data(label2,1),data(label2,2),data(label2,3),'go');

plot3(data(label3,1),data(label3,2),data(label3,3),'bo'); 

grid on  %歸一化處理 

maxV = max(dataMat);

minV = min(dataMat);

range = maxV-minV;

newdataMat =  (dataMat-repmat(minV,[len,1]))./(repmat(range,[len,1])); 

%測試數據比例

Ratio = 0.1; 

numTest = Ratio * len; % 100條測試, 900條訓練

 

%訓練數據和測試數據 

TrainData=newdataMat(numTest+1:end,:);

TrainLabels=labels(numTest+1:end,:);

TestData=newdataMat(1:numTest,:);

TestLabels=labels(1:numTest,:); %測試,歐氏距離type=1, 夾角余弦type=2 

type=1; 

for i = 1:numTest

    classifyresult =  KNN(TestData(i,:),TrainData,TrainLabels,k,type);

    % fprintf('第 %d 條記錄,測試結果為:%d  真實結果為:%d\n',[iclassifyresult(i) labels(i)])

    [classifyresult labels(i)])

        if(classifyresult~=labels(i))

            error = error+1;

        end

end

classifyresult=classifyresult';

fprintf('分類錯誤的記錄標簽為:') 

Index=find(classifyresult~=TestLabels)

fprintf('准確率為:%f\n',1-error/(numTest))


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM