KNN(K - Nearest Neighbor)分類算法是模式識別領域的一個簡單分類方法。KNN算法的核心思想是,如果一個樣本在特征空間中的k個最相鄰的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別,並具有這個類別上樣本的特性。該方法在確定分類決策上只依據最鄰近的k個樣本的類別來決定待分樣本所屬的類別。
首先,knn算法比較適合只有兩類樣本的簡單分類問題,這樣當n為奇數時就可以少數服從多數達到分類目的。但是當類別數量大於2時,設定n為奇數已經不能避免投票數量相同的問題。這種情況在我的理解下有兩種解決思路:
思路A:當對運算效率要求較高而對分類結果要求不高時,可以選擇投票數量相同的樣本類別中先被查詢到的一個;
思路B:當對運算效率要求不高而對分類結果要求較高時,可以改變k值再次進行投票,例如,在k0=5時出現了投票數量相同的樣本類別,可以令k1=k0-1再次進行判斷,若仍存在投票數量相同的樣本類別,可以繼續令k2=k1-1再次進行判斷,這樣在km=1時會得到唯一答案。這里不建議增大k值再次投票,因為我擔心會陷入更麻煩的情況。
KNN作為機器學習的入門算法,存在效率低、過度依賴訓練數據的缺點,並在處理較大數據時可能引起維數災難。需要謹慎考慮后再選擇。
在普通的KNN算法下,當k個最近鄰樣本進行投票時,存在投票數量相同的樣本類別,即MaxValue的長度不為1時程序會報錯。根據思路A,可將MaxValue改為MaxValue(1)。下面的算法在這里根據思路B進行改進,具體方法是減小k值遞歸調用knn函數。
同時為了滿足壓縮近鄰法的需要,處理了當訓練集數據不足K個時出現的問題。解決方法是,當訓練集數據不足K個時,令K為訓練集數據的個數。
改進后的KNN算法:
function y = knn(trainData, sample_label, testData, k) %KNN k-Nearest Neighbors Algorithm. % % INPUT: trainData: training sample Data, M-by-N matrix. % sample_label: training sample labels, 1-by-N row vector. % testData: testing sample Data, M-by-N_test matrix. % K: the k in k-Nearest Neighbors % % OUTPUT: y : predicted labels, 1-by-N_test row vector. % % Author: Sophia_Dz if length(trainData) < k k = length(trainData); end [M_train, N] = size(trainData); [M_test, ~] = size(testData); %calculate the distance between testData and trainData Dis = zeros(M_train,1); class_test = zeros(M_test,1); for n = 1:M_test for i = 1:M_train distance1 = 0; for j = 1:N distance1 = (testData(n,j) - trainData(i,j)).^2 + distance1; end Dis(i,1) = distance1.^0.5; end %find the k nearest neighbor [~, index] = sort(Dis); temp=1:k; for i = 1:k temp(i) = sample_label(index(i)); end table = tabulate(temp); MaxCount=max(table(:,2,:)); [row,~]=find(table(:,2,:)==MaxCount); MaxValue=table(row,1); if length(MaxValue) ~= 1 MaxValue = knn(trainData, sample_label, testData(n,:), k-1); end class_test(n) = MaxValue; end y = class_test;
下面是剪輯近鄰法與壓縮近鄰法的MATLAB實現:
首先設定參數:
%% parameter determination clear; % dataset parameter dataset_len=400; dataset_proportion=[8,2]; attribute_num=2; % knn parameter k=5; % edit parameter m=4; s=4;
准備數據,這里使用MATLAB生成服從正態分布的三組數據,它們的均值不同:
%% dataset load dataset_class_len=fix(dataset_len/3); dataset=[ 4*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class A attribute x 2*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class A attribute y 1*ones(dataset_class_len,1); % class A label 2*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class B attribute x 4*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class B attribute y 2*ones(dataset_class_len,1); % class B label 5*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class C attribute x 5*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class C attribute y 3*ones(dataset_class_len,1) % class C label ];
將數據划分為訓練集與測試集:
%% preprocess data % order disrupt rand_class_index=randperm(size(dataset,1)); dataset=dataset(rand_class_index,:); % train dataset and test dataset dataset_train_len=fix(dataset_len*(dataset_proportion(1)/sum(dataset_proportion))); dataset_train=dataset(1:dataset_train_len,:); dataset_test=dataset(dataset_train_len+1:end,:); % attribute and label dataset_train_attribute=dataset_train(:,1:attribute_num); dataset_train_label=dataset_train(:,attribute_num+1); dataset_test_attribute=dataset_test(:,1:attribute_num); dataset_test_label=dataset_test(:,attribute_num+1);
將訓練集的樣本可視化:
%% train dataset visualization data_vis=dataset_train; figure(1); for n=1:length(data_vis) X=data_vis(n,1); Y=data_vis(n,2); if data_vis(n,attribute_num+1)==1 color='red'; elseif data_vis(n,attribute_num+1)==2 color='green'; elseif data_vis(n,attribute_num+1)==3 color='blue'; end plot(X,Y,'+','Color',color); hold on; end
圖1 未經處理的訓練集樣本
分類:
%% knn
classification_result=knn(dataset_train_attribute,dataset_train_label,dataset_test_attribute,k);
計算根據此訓練集的分類正確率為:0.8481:
%% correct rate error_count=0; for n=1:length(classification_result) if dataset_test_label(n)~=classification_result(n) error_count=error_count+1; end end correct_rate=1-error_count/length(classification_result);
剪輯近鄰法:
剪輯近鄰法的基本思想是,當不同類別的樣本在分布上有交迭部分的,分類的錯誤率主要來自處於交迭區中的樣本,通過剪輯去除大部分交迭區中的樣本。
1. 將訓練集隨機划分成s組;
2. 其中i組作為訓練集,i+1組樣本作為測試集,用訓練集中的樣本對測試集中的樣本進行最近鄰分類,如果類別不同,則從測試集中分類錯誤的樣本去除;
3. 若達到m次沒有新的樣本被去除,剪輯完成。
%% edit nearest neighbor % init dataset_edit=dataset_train; loop=0; add_old=1; while loop<m rand_class_index=ceil(unifrnd(0,s,length(dataset_edit),1)); dataset_new=zeros(length(dataset_edit),attribute_num+1); add=1; for i=1:s % train set and test set test_set=dataset_edit((rand_class_index==i),:); if i<s j=i+1; else j=1; end train_set=dataset_edit((rand_class_index==j),:); train_set_attribute=train_set(:,1:attribute_num); train_set_label=train_set(:,attribute_num+1); test_set_attribute=test_set(:,1:attribute_num); test_set_label=test_set(:,attribute_num+1); % classification result=knn(train_set_attribute,train_set_label,test_set_attribute,k); for num=1:length(result) if(result(num)==test_set_label(num)) dataset_new(add,:)=test_set(num,:); add=add+1; end end end dataset_edit=dataset_new(1:add-1,:); if(add==add_old) loop=loop+1; end add_old=add; end
剪輯后的訓練集:
%% edit data visualization data_vis=dataset_edit; figure(2); for n=1:length(data_vis) X=data_vis(n,1); Y=data_vis(n,2); if data_vis(n,attribute_num+1)==1 color='red'; elseif data_vis(n,attribute_num+1)==2 color='green'; elseif data_vis(n,attribute_num+1)==3 color='blue'; end plot(X,Y,'+','Color',color); hold on; end
圖2 剪輯處理的訓練集樣本
分類:
%% edit knn dataset_edit_attribute=dataset_edit(:,1:attribute_num); dataset_edit_label=dataset_edit(:,attribute_num+1); classification_result_edit=knn(dataset_edit_attribute,dataset_edit_label,dataset_test_attribute,k);
計算根據此訓練集的分類正確率為:0.8734:
%% edit correct rate error_count=0; for n=1:length(classification_result_edit) if dataset_test_label(n)~=classification_result_edit(n) error_count=error_count+1; end end correct_rate_edit=1-error_count/length(classification_result_edit);
壓縮近鄰法:
壓縮近鄰法壓縮樣本的思想,它利用現有樣本集,逐漸生成一個新的樣本集。使該樣本集在保留最少量樣本的條件下, 仍能對原有樣本的全部用最近鄰法正確分類,那么該樣本集也就能對待識別樣本進行分類, 並保持正常識別率。
1. 對初始訓練集,將其划分為兩個部分Store和Garbbag,初始Store樣本集合為空。
2. 從初始訓練集中隨機選擇一個樣本放入Store中,其它樣本放入Garbbag中,用其對Garbbag中的每一個樣本進行分類。若樣本i能夠被正確分類,則將其放回到Garbbag中;否則將其加入到Store中;
3. 重復上述過程,直到Garbbag中所有樣本都能正確分類為止。
%% condense nearest neighbor % init dataset_condense=dataset_train; store=zeros(size(dataset_condense)); store_count=0; garbbag=dataset_condense; garbbag_count=length(dataset_condense); add=garbbag_count; % move one store(store_count+1,:)=garbbag(add,:); garbbag(add,:)=[]; store_count=store_count+1; garbbag_count=garbbag_count-1; add=add-1; store_attribute=store(:,1:attribute_num); store_label=store(:,attribute_num+1); garbbag_attribute=garbbag(:,1:attribute_num); garbbag_label=garbbag(:,attribute_num+1); change_flag=1; while change_flag==1 change_flag=0; add=garbbag_count; while add>0 result=knn(store_attribute,store_label,garbbag_attribute(add,:),k); if result~=garbbag_label(add) change_flag=1; store(store_count+1,:)=garbbag(add,:); garbbag(add,:)=[]; store_count=store_count+1; garbbag_count=garbbag_count-1; add=add-1; store_attribute=store(:,1:attribute_num); store_label=store(:,attribute_num+1); garbbag_attribute=garbbag(:,1:attribute_num); garbbag_label=garbbag(:,attribute_num+1); end add=add-1; end end dataset_condense=store;
壓縮后的訓練集:
%% condense data visualization data_vis=dataset_condense; figure(3); for n=1:length(data_vis) X=data_vis(n,1); Y=data_vis(n,2); if data_vis(n,attribute_num+1)==1 color='red'; elseif data_vis(n,attribute_num+1)==2 color='green'; elseif data_vis(n,attribute_num+1)==3 color='blue'; end plot(X,Y,'+','Color',color); hold on; end
圖3 壓縮處理的訓練集樣本
分類:
%% condense knn dataset_condense_attribute=dataset_condense(:,1:attribute_num); dataset_condense_label=dataset_condense(:,attribute_num+1); classification_result_condense=knn(dataset_condense_attribute,dataset_condense_label,dataset_test_attribute,k);
計算根據此訓練集的分類正確率為:0.8228:
%% condense correct rate error_count=0; for n=1:length(classification_result_condense) if dataset_test_label(n)~=classification_result_condense(n) error_count=error_count+1; end end correct_rate_condense=1-error_count/length(classification_result_condense);
將剪輯后的樣本壓縮處理:
%% edit and condense nearest neighbor % init dataset_ec=dataset_edit; store=zeros(size(dataset_ec)); store_count=0; garbbag=dataset_ec; garbbag_count=length(dataset_ec); add=garbbag_count; % move one store(store_count+1,:)=garbbag(add,:); garbbag(add,:)=[]; store_count=store_count+1; garbbag_count=garbbag_count-1; add=add-1; store_attribute=store(:,1:attribute_num); store_label=store(:,attribute_num+1); garbbag_attribute=garbbag(:,1:attribute_num); garbbag_label=garbbag(:,attribute_num+1); change_flag=1; while change_flag==1 change_flag=0; add=garbbag_count; while add>0 result=knn(store_attribute,store_label,garbbag_attribute(add,:),k); if result~=garbbag_label(add) change_flag=1; store(store_count+1,:)=garbbag(add,:); garbbag(add,:)=[]; store_count=store_count+1; garbbag_count=garbbag_count-1; add=add-1; store_attribute=store(:,1:attribute_num); store_label=store(:,attribute_num+1); garbbag_attribute=garbbag(:,1:attribute_num); garbbag_label=garbbag(:,attribute_num+1); end add=add-1; end end dataset_ec=store;
壓縮處理剪輯后的樣本:
%% edit and condense data visualization data_vis=dataset_ec; figure(4); for n=1:length(data_vis) X=data_vis(n,1); Y=data_vis(n,2); if data_vis(n,attribute_num+1)==1 color='red'; elseif data_vis(n,attribute_num+1)==2 color='green'; elseif data_vis(n,attribute_num+1)==3 color='blue'; end plot(X,Y,'+','Color',color); hold on; end
圖4 剪輯壓縮處理的訓練集樣本
分類:
%% edit and condense knn dataset_ec_attribute=dataset_ec(:,1:attribute_num); dataset_ec_label=dataset_ec(:,attribute_num+1); classification_result_ec=knn(dataset_ec_attribute,dataset_ec_label,dataset_test_attribute,k);
計算根據此訓練集的分類正確率為:0.8228:
%% edit and condense correct rate error_count=0; for n=1:length(classification_result_ec) if dataset_test_label(n)~=classification_result_ec(n) error_count=error_count+1; end end correct_rate_ec=1-error_count/length(classification_result_ec);
經過多次實驗得到以下結論:
1. 剪輯處理能夠去除分類邊界的樣本;
2. 壓縮近鄰主要去除樣本中靠近中心的樣本;
3. 剪輯近鄰法可以去除部分樣本並在一定程度上提高分類正確率;
4. 壓縮近鄰法可以去除大量樣本。