KNN(K-Nearest Neighbor)算法Matlab實現


%實現KNN算法  
%%算法描述  
%1、初始化訓練集和類別;  
%2、計算測試集樣本與訓練集樣本的歐氏距離;  
%3、根據歐氏距離大小對訓練集樣本進行升序排序;  
%4、選取歐式距離最小的前K個訓練樣本,統計其在各類別中的頻率;  
%5、返回頻率最大的類別,即測試集樣本屬於該類別。  
close all;  
clc;  
  
%%算法實現  
%step1、初始化訓練集、測試集、K值  
%創建一個三維矩陣,二維表示同一類下的二維坐標點,第三維表示類別  
  
trainData1=[0 0;0.1 0.3;0.2 0.1;0.2 0.2];%第一類訓練數據  
trainData2=[1 0;1.1 0.3;1.2 0.1;1.2 0.2];%第二類訓練數據  
trainData3=[0 1;0.1 1.3;0.2 1.1;0.2 1.2];%第三類訓練數據  
trainData(:,:,1)=trainData1;%設置第一類測試數據  
trainData(:,:,2)=trainData2;%設置第二類測試數據  
trainData(:,:,3)=trainData3;%設置第三類測試數據  
  
trainDim=size(trainData);%獲取訓練集的維數  
  
testData=[1.6 0.3];%設置1個測試點  
  
K=7;  
  
%%分別計算測試集中各個點與每個訓練集中的點的歐氏距離  
%把測試點擴展成矩陣  
testData_rep=repmat(testData,4,1);  
%設置三個二維矩陣存放測試集與測試點的擴展矩陣的差值平方  
  
%diff1=zero(trainDim(1),trianDim(2));  
%diff2=zero(trainDim(1),trianDim(2));  
%diff3=zero(trainDim(1),trianDim(2));  
  
for i=1:trainDim(3)  
    diff1=(trainData(:,:,1)-testData_rep).^2;  
    diff2=(trainData(:,:,2)-testData_rep).^2;  
    diff3=(trainData(:,:,3)-testData_rep).^2;  
end  
  
%設置三個一維數組存放歐式距離  
distance1=(diff1(:,1)+diff1(:,2)).^0.5;  
distance2=(diff2(:,1)+diff2(:,2)).^0.5;  
distance3=(diff3(:,1)+diff3(:,2)).^0.5;  
  
%將三個一維數組合成一個二維矩陣  
temp=[distance1 distance2 distance3];  
%將這個二維矩陣轉換為一維數組  
distance=reshape(temp,1,3*4);  
%對距離進行排序  
distance_sort=sort(distance);  
%用一個循環尋找最小的K個距離里面那個類里出現的頻率最高,並返回該類  
num1=0;%第一類出現的次數  
num2=0;%第二類出現的次數  
num3=0;%第三類出現的次數  
sum=0;%sum1,sum2,sum3的和  
for i=1:K  
    for j=1:4  
        if distance1(j)==distance_sort(i)  
            num1=num1+1;  
        end  
        if distance2(j)==distance_sort(i)  
            num2=num2+1;  
        end  
        if distance3(j)==distance_sort(i)  
            num3=num3+1;  
        end  
    end  
    sum=num1+num2+num3;  
    if sum>=K  
        break;  
    end  
end  
  
class=[num1 num2 num3];  
  
classname=find(class(1,:)==max(class));  
  
fprintf('測試點(%f %f)屬於第%d類',testData(1),testData(2),classname);  
  
%%使用繪圖將訓練集點和測試集點繪畫出來  
figure(1);  
hold on;  
for i=1:4  
    plot(trainData1(i,1),trainData1(i,2),'*');  
    plot(trainData2(i,1),trainData2(i,2),'o');  
    plot(trainData3(i,1),trainData3(i,2),'>');  
end  
plot(testData(1),testData(2),'x');  
text(0.1,0.1,'第一類');  
text(1.1,0.1,'第二類');  
text(0.1,1,'第三類');  

轉載於https://blog.csdn.net/queyuze/article/details/70195087


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM