KNN的改進算法、剪輯近鄰法與壓縮近鄰法的MATLAB實現


KNN(K - Nearest Neighbor)分類算法是模式識別領域的一個簡單分類方法。KNN算法的核心思想是,如果一個樣本在特征空間中的k個最相鄰的樣本中的大多數屬於某一個類別,則該樣本也屬於這個類別,並具有這個類別上樣本的特性。該方法在確定分類決策上只依據最鄰近的k個樣本的類別來決定待分樣本所屬的類別。

首先,knn算法比較適合只有兩類樣本的簡單分類問題,這樣當n為奇數時就可以少數服從多數達到分類目的。但是當類別數量大於2時,設定n為奇數已經不能避免投票數量相同的問題。這種情況在我的理解下有兩種解決思路:

思路A:當對運算效率要求較高而對分類結果要求不高時,可以選擇投票數量相同的樣本類別中先被查詢到的一個;

思路B:當對運算效率要求不高而對分類結果要求較高時,可以改變k值再次進行投票,例如,在k0=5時出現了投票數量相同的樣本類別,可以令k1=k0-1再次進行判斷,若仍存在投票數量相同的樣本類別,可以繼續令k2=k1-1再次進行判斷,這樣在km=1時會得到唯一答案。這里不建議增大k值再次投票,因為我擔心會陷入更麻煩的情況。

KNN作為機器學習的入門算法,存在效率低、過度依賴訓練數據的缺點,並在處理較大數據時可能引起維數災難。需要謹慎考慮后再選擇。

在普通的KNN算法下,當k個最近鄰樣本進行投票時,存在投票數量相同的樣本類別,即MaxValue的長度不為1時程序會報錯。根據思路A,可將MaxValue改為MaxValue(1)。下面的算法在這里根據思路B進行改進,具體方法是減小k值遞歸調用knn函數。

同時為了滿足壓縮近鄰法的需要,處理了當訓練集數據不足K個時出現的問題。解決方法是,當訓練集數據不足K個時,令K為訓練集數據的個數。

改進后的KNN算法:

function y = knn(trainData, sample_label, testData, k)

%KNN k-Nearest Neighbors Algorithm.
%
%   INPUT:  trainData:       training sample Data, M-by-N matrix.
%           sample_label:    training sample labels, 1-by-N row vector.
%           testData:        testing sample Data, M-by-N_test matrix.
%           K:               the k in k-Nearest Neighbors
%
%   OUTPUT: y    : predicted labels, 1-by-N_test row vector.
%
% Author: Sophia_Dz

if length(trainData) < k
    k = length(trainData);
end
 
[M_train, N] = size(trainData);
[M_test, ~] = size(testData);
 
%calculate the distance between testData and trainData

Dis = zeros(M_train,1);
class_test = zeros(M_test,1);
for n = 1:M_test
    for i = 1:M_train
        distance1 = 0;
        for j = 1:N
            distance1 = (testData(n,j) - trainData(i,j)).^2 + distance1;
        end
        Dis(i,1) = distance1.^0.5;
    end
    %find the k nearest neighbor
    [~, index] = sort(Dis);
    temp=1:k;
    for i = 1:k
        temp(i) = sample_label(index(i));
    end
    table = tabulate(temp);
    MaxCount=max(table(:,2,:));
    [row,~]=find(table(:,2,:)==MaxCount);
    MaxValue=table(row,1);
    if length(MaxValue) ~= 1
        MaxValue = knn(trainData, sample_label, testData(n,:), k-1);
    end
    class_test(n) = MaxValue;
end
y = class_test;

下面是剪輯近鄰法與壓縮近鄰法的MATLAB實現:

首先設定參數:

%% parameter determination
clear;
% dataset parameter
dataset_len=400;
dataset_proportion=[8,2];
attribute_num=2;
% knn parameter
k=5;
% edit parameter
m=4;
s=4;

准備數據,這里使用MATLAB生成服從正態分布的三組數據,它們的均值不同:

%% dataset load
dataset_class_len=fix(dataset_len/3);
dataset=[
    4*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class A attribute x
    2*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class A attribute y
    1*ones(dataset_class_len,1);                               % class A label
    2*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class B attribute x
    4*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class B attribute y
    2*ones(dataset_class_len,1);                               % class B label
    5*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class C attribute x
    5*ones(dataset_class_len,1)+randn(dataset_class_len,1),... % class C attribute y
    3*ones(dataset_class_len,1)                                % class C label
    ];

將數據划分為訓練集與測試集:

%% preprocess data
% order disrupt
rand_class_index=randperm(size(dataset,1));
dataset=dataset(rand_class_index,:);
% train dataset and test dataset
dataset_train_len=fix(dataset_len*(dataset_proportion(1)/sum(dataset_proportion)));
dataset_train=dataset(1:dataset_train_len,:);
dataset_test=dataset(dataset_train_len+1:end,:);
% attribute and label
dataset_train_attribute=dataset_train(:,1:attribute_num);
dataset_train_label=dataset_train(:,attribute_num+1);
dataset_test_attribute=dataset_test(:,1:attribute_num);
dataset_test_label=dataset_test(:,attribute_num+1);

將訓練集的樣本可視化:

%% train dataset visualization
data_vis=dataset_train;
figure(1);
for n=1:length(data_vis)
    X=data_vis(n,1);
    Y=data_vis(n,2);
    if data_vis(n,attribute_num+1)==1
        color='red';
    elseif data_vis(n,attribute_num+1)==2
        color='green';
    elseif data_vis(n,attribute_num+1)==3
        color='blue';
    end
    plot(X,Y,'+','Color',color);
    hold on;
end

圖1 未經處理的訓練集樣本

分類:

%% knn
classification_result=knn(dataset_train_attribute,dataset_train_label,dataset_test_attribute,k);

計算根據此訓練集的分類正確率為:0.8481:

%% correct rate
error_count=0;
for n=1:length(classification_result)
    if dataset_test_label(n)~=classification_result(n)
        error_count=error_count+1;
    end
end
correct_rate=1-error_count/length(classification_result);

剪輯近鄰法:

剪輯近鄰法的基本思想是,當不同類別的樣本在分布上有交迭部分的,分類的錯誤率主要來自處於交迭區中的樣本,通過剪輯去除大部分交迭區中的樣本。

1. 將訓練集隨機划分成s組;

2. 其中i組作為訓練集,i+1組樣本作為測試集,用訓練集中的樣本對測試集中的樣本進行最近鄰分類,如果類別不同,則從測試集中分類錯誤的樣本去除

3. 若達到m次沒有新的樣本被去除,剪輯完成

%% edit nearest neighbor
% init
dataset_edit=dataset_train;
loop=0;
add_old=1;
while loop<m
    rand_class_index=ceil(unifrnd(0,s,length(dataset_edit),1));
    dataset_new=zeros(length(dataset_edit),attribute_num+1);
    add=1;
    for i=1:s
        % train set and test set
        test_set=dataset_edit((rand_class_index==i),:);
        if i<s
            j=i+1;
        else
            j=1;
        end
        train_set=dataset_edit((rand_class_index==j),:);
        train_set_attribute=train_set(:,1:attribute_num);
        train_set_label=train_set(:,attribute_num+1);
        test_set_attribute=test_set(:,1:attribute_num);
        test_set_label=test_set(:,attribute_num+1);
        % classification
        result=knn(train_set_attribute,train_set_label,test_set_attribute,k);
        for num=1:length(result)
            if(result(num)==test_set_label(num))
                dataset_new(add,:)=test_set(num,:);
                add=add+1;
            end
        end
    end
    dataset_edit=dataset_new(1:add-1,:);
    if(add==add_old)
        loop=loop+1;
    end
    add_old=add;
end

剪輯后的訓練集:

%% edit data visualization
data_vis=dataset_edit;
figure(2);
for n=1:length(data_vis)
    X=data_vis(n,1);
    Y=data_vis(n,2);
    if data_vis(n,attribute_num+1)==1
        color='red';
    elseif data_vis(n,attribute_num+1)==2
        color='green';
    elseif data_vis(n,attribute_num+1)==3
        color='blue';
    end
    plot(X,Y,'+','Color',color);
    hold on;
end

圖2 剪輯處理的訓練集樣本

分類:

%% edit knn
dataset_edit_attribute=dataset_edit(:,1:attribute_num);
dataset_edit_label=dataset_edit(:,attribute_num+1);
classification_result_edit=knn(dataset_edit_attribute,dataset_edit_label,dataset_test_attribute,k);

計算根據此訓練集的分類正確率為:0.8734:

%% edit correct rate
error_count=0;
for n=1:length(classification_result_edit)
    if dataset_test_label(n)~=classification_result_edit(n)
        error_count=error_count+1;
    end
end
correct_rate_edit=1-error_count/length(classification_result_edit);

壓縮近鄰法:

壓縮近鄰法壓縮樣本的思想,它利用現有樣本集,逐漸生成一個新的樣本集。使該樣本集在保留最少量樣本的條件下, 仍能對原有樣本的全部用最近鄰法正確分類,那么該樣本集也就能對待識別樣本進行分類, 並保持正常識別率。

1. 對初始訓練集,將其划分為兩個部分Store和Garbbag,初始Store樣本集合為空。

2. 從初始訓練集中隨機選擇一個樣本放入Store中,其它樣本放入Garbbag中,用其對Garbbag中的每一個樣本進行分類。若樣本i能夠被正確分類,則將其放回到Garbbag中;否則將其加入到Store中;

3. 重復上述過程,直到Garbbag中所有樣本都能正確分類為止。

%% condense nearest neighbor
% init
dataset_condense=dataset_train;
store=zeros(size(dataset_condense));
store_count=0;
garbbag=dataset_condense;
garbbag_count=length(dataset_condense);
add=garbbag_count;
% move one
store(store_count+1,:)=garbbag(add,:);
garbbag(add,:)=[];
store_count=store_count+1;
garbbag_count=garbbag_count-1;
add=add-1;
store_attribute=store(:,1:attribute_num);
store_label=store(:,attribute_num+1);
garbbag_attribute=garbbag(:,1:attribute_num);
garbbag_label=garbbag(:,attribute_num+1);
change_flag=1;
while change_flag==1
    change_flag=0;
    add=garbbag_count;
    while add>0
        result=knn(store_attribute,store_label,garbbag_attribute(add,:),k);
        if result~=garbbag_label(add)
            change_flag=1;
            store(store_count+1,:)=garbbag(add,:);
            garbbag(add,:)=[];
            store_count=store_count+1;
            garbbag_count=garbbag_count-1;
            add=add-1;
            store_attribute=store(:,1:attribute_num);
            store_label=store(:,attribute_num+1);
            garbbag_attribute=garbbag(:,1:attribute_num);
            garbbag_label=garbbag(:,attribute_num+1);
        end
        add=add-1;
    end
end
dataset_condense=store;

壓縮后的訓練集:

%% condense data visualization
data_vis=dataset_condense;
figure(3);
for n=1:length(data_vis)
    X=data_vis(n,1);
    Y=data_vis(n,2);
    if data_vis(n,attribute_num+1)==1
        color='red';
    elseif data_vis(n,attribute_num+1)==2
        color='green';
    elseif data_vis(n,attribute_num+1)==3
        color='blue';
    end
    plot(X,Y,'+','Color',color);
    hold on;
end

圖3 壓縮處理的訓練集樣本

分類:

%% condense knn
dataset_condense_attribute=dataset_condense(:,1:attribute_num);
dataset_condense_label=dataset_condense(:,attribute_num+1);
classification_result_condense=knn(dataset_condense_attribute,dataset_condense_label,dataset_test_attribute,k);

計算根據此訓練集的分類正確率為:0.8228:

%% condense correct rate
error_count=0;
for n=1:length(classification_result_condense)
    if dataset_test_label(n)~=classification_result_condense(n)
        error_count=error_count+1;
    end
end
correct_rate_condense=1-error_count/length(classification_result_condense);

將剪輯后的樣本壓縮處理:

%% edit and condense nearest neighbor
% init
dataset_ec=dataset_edit;
store=zeros(size(dataset_ec));
store_count=0;
garbbag=dataset_ec;
garbbag_count=length(dataset_ec);
add=garbbag_count;
% move one
store(store_count+1,:)=garbbag(add,:);
garbbag(add,:)=[];
store_count=store_count+1;
garbbag_count=garbbag_count-1;
add=add-1;
store_attribute=store(:,1:attribute_num);
store_label=store(:,attribute_num+1);
garbbag_attribute=garbbag(:,1:attribute_num);
garbbag_label=garbbag(:,attribute_num+1);
change_flag=1;
while change_flag==1
    change_flag=0;
    add=garbbag_count;
    while add>0
        result=knn(store_attribute,store_label,garbbag_attribute(add,:),k);
        if result~=garbbag_label(add)
            change_flag=1;
            store(store_count+1,:)=garbbag(add,:);
            garbbag(add,:)=[];
            store_count=store_count+1;
            garbbag_count=garbbag_count-1;
            add=add-1;
            store_attribute=store(:,1:attribute_num);
            store_label=store(:,attribute_num+1);
            garbbag_attribute=garbbag(:,1:attribute_num);
            garbbag_label=garbbag(:,attribute_num+1);
        end
        add=add-1;
    end
end
dataset_ec=store;

壓縮處理剪輯后的樣本:

%% edit and condense data visualization
data_vis=dataset_ec;
figure(4);
for n=1:length(data_vis)
    X=data_vis(n,1);
    Y=data_vis(n,2);
    if data_vis(n,attribute_num+1)==1
        color='red';
    elseif data_vis(n,attribute_num+1)==2
        color='green';
    elseif data_vis(n,attribute_num+1)==3
        color='blue';
    end
    plot(X,Y,'+','Color',color);
    hold on;
end

圖4 剪輯壓縮處理的訓練集樣本

分類:

%% edit and condense knn
dataset_ec_attribute=dataset_ec(:,1:attribute_num);
dataset_ec_label=dataset_ec(:,attribute_num+1);
classification_result_ec=knn(dataset_ec_attribute,dataset_ec_label,dataset_test_attribute,k);

計算根據此訓練集的分類正確率為:0.8228:

%% edit and condense correct rate
error_count=0;
for n=1:length(classification_result_ec)
    if dataset_test_label(n)~=classification_result_ec(n)
        error_count=error_count+1;
    end
end
correct_rate_ec=1-error_count/length(classification_result_ec);

經過多次實驗得到以下結論:

1. 剪輯處理能夠去除分類邊界的樣本;

2. 壓縮近鄰主要去除樣本中靠近中心的樣本;

3. 剪輯近鄰法可以去除部分樣本並在一定程度上提高分類正確率;

4. 壓縮近鄰法可以去除大量樣本。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM