LIBSVM (四) SVM 的參數優化(交叉驗證)


  CV是用來驗證分類器性能的一種統計分析方法,其基本思想是在某種意義下將原始數據進行分組,一部分作為測試集,另一部分作為驗證集;先用訓練集對分類器進行訓練,再利用驗證集來測試訓練得到的模型,以得到的分類准確率作為評價分類器的性能指標。常見的 CV 方法如下:

1.1 交叉驗證(Cross Validation,CV)

  偽代碼:

Start 
    bestAccuracy = 0;
    bestc = 0;
    bestg = 0;
    % 其中n1,n2都是預先給定的數
    for c = 2^(-n1):2(n1)
        for g = 2^(-n2):2^(n2)
                將訓練集平均分成 N 部分,設為
train(1),train(2),···,train(N)
                分別讓每一個部分作為測試及進行預測(剩下的N-1部分作為訓練集對訓練集進行訓練),取得最后
                得到的所有分類准確率的平均數,設為 cv.
                if(cv > bestAccuracy)
                    bestAccuracy = cv;bestc = c;bestg = g;
                end
            end
end
Over

   采用CV的方法,在沒有測試標簽的情況下可以找到一定意義下的最佳參數c和g。這里所說的“一定意義下”指的是此時最佳參數c和g是使得訓練集在CV思想下能夠達到最高分類准確率的參數,但不能保證會使得測試集也達到最高的分類准確率。用此方法對wine數據進行分類預測,MATLAB 實現代碼如下:

%% 交叉驗證
bestcv = 0;
bestc = 0;
bestg = 0;
for log2c = -5:5
    for log2g = -5:5
        cmd =['-v 3 -c ',num2str(2^log2c), ' -g ',num2str(2^log2g)];
        cv = svmtrain(train_wine_labels,train_wine,cmd);
        if(cv > bestcv)
        bestcv = cv;
        bestc = 2^log2c;
        bestg = 2^log2g;
        end
    end
end

fprintf('(best c = %g,g = %g,rate = %g)\n',bestc,bestg,bestcv);
cmd = ['-c ',num2str(bestc),' -g ',num2str(bestg)];
model = svmtrain(train_wine_labels,train_wine,cmd);
[predict_label,accuracy,dec_value]=svmpredict(test_wine_labels,test_wine,model);

 (best c = 2,g = 0.5,rate = 98.8764)

Accuracy = 98.8764% (88/89) (classification)

1.2  K-CV 算法(K - fold Cross Validation )

   將原始數據分成K組(一般是均分),將每個子集數據分別做一次驗證集,其余的 K-1 組子集數據作為訓練集,這樣會得到 K 個模型,用這 K 個模型最終的驗證集的分類准確率的平均值作為此 K-CV 下分類器的性能指標。K>=2,一般實際操作是取K=3,只有在原始數據集合數據量小的時候才會嘗試取 2。K - CV 可以有效地避免過學習以及欠學習狀態的發生,最后得到的結果也比較具有說服性。

  這節主要涉及到函數 SVMcgForClass 的理解,以及使用技巧

1.2.1 SVMcgForClass

   偽代碼:

Start 
    bestAccuracy = 0;
    bestc = 0;
    bestg = 0;
    % 將 c 和 g 划分網格進行搜索
    for c = 2^(-n1):2(n1)
        for g = 2^(-n2):2^(n2)
               %利用 K-CV 方法
將train平均分成 K 組,
記train(1),train(2),···,train(K)
相應的標簽也要分離出來,
記為 train_label(1),train_label(2),···,train_label(k),
            for  run = 1:k  %讓train(run),作為驗證集,其他作為訓練集,記錄此時acc(run)
            end
            cv = (acc(1)+acc(2)+```+acc(K))/k;
            if(cv > bestAccuracy)
                bestAccuracy = cv;bestc = c;bestg = g;
            end
        end
end
Over

函數接口及輸入參數解析

[bestacc,bestc,bestg]=SVMcgForClass(train_label,train,cmin,cmax,gmin,gmax,v,cstep,gstep,accstep)
train_label :訓練集標簽
train:訓練集
cmin:懲罰參數c的變化范圍(以2為底的冪指數后),即c_min=2……(cmin),默認值為-5;
cmax:懲罰參數c的變化范圍(以2為底的冪指數后),即c_max=2……(cmax),默認值為5;
gmin:參數g的變化范圍最小值(以2為底的冪指數后),即g_min=2……(gmin),默認值為-5;
gmax:參數g的變化范圍最大值(以2為底的冪指數后),即g_max=2……(gmax),默認值為5;
v:CV的參數,即給測試集分幾部分進行CV。默認值為3.
cstep:參數c步進的大小(取以2為底的冪指數后),默認值為1.
gstep:參數g步進的大小(取以2為底的冪指數后),默認值為1.
accstep:最后顯示准確率圖時的步進大小

 以上參數只有train_label和train是必須輸入的,其他的可不輸入采用默認值。

1.2.2 代碼實現

% about the parameters of SVMcg 
if nargin < 10
    accstep = 4.5;
end
if nargin < 8
    cstep = 0.8;
    gstep = 0.8;
end
if nargin < 7
    v = 5;
end
if nargin < 5
    gmax = 8;
    gmin = -8;
end
if nargin < 3
    cmax = 8;
    cmin = -8;
end
% X:c Y:g cg:CVaccuracy
[X,Y] = meshgrid(cmin:cstep:cmax,gmin:gstep:gmax);
[m,n] = size(X);
cg = zeros(m,n);

eps = 10^(-4);

% record acc with different c & g,and find the bestacc with the smallest c
bestc = 1;
bestg = 0.1;
bestacc = 0;
basenum = 2;
for i = 1:m
    for j = 1:n
        cmd = ['-v ',num2str(v),' -c ',num2str( basenum^X(i,j) ),' -g ',num2str( basenum^Y(i,j) )];
        cg(i,j) = svmtrain(train_label, train, cmd);
        
        if cg(i,j) <= 55
            continue;
        end
        
        if cg(i,j) > bestacc
            bestacc = cg(i,j);
            bestc = basenum^X(i,j);
            bestg = basenum^Y(i,j);
        end        
        
        if abs( cg(i,j)-bestacc )<=eps && bestc > basenum^X(i,j) 
            bestacc = cg(i,j);
            bestc = basenum^X(i,j);
            bestg = basenum^Y(i,j);
        end        
        
    end
end
% to draw the acc with different c & g
figure;
[C,h] = contour(X,Y,cg,70:accstep:100);
clabel(C,h,'Color','r');
xlabel('log2c','FontSize',12);
ylabel('log2g','FontSize',12);
firstline = 'SVC參數選擇結果圖(等高線圖)[GridSearchMethod]'; 
secondline = ['Best c=',num2str(bestc),' g=',num2str(bestg), ...
    ' CVAccuracy=',num2str(bestacc),'%'];
title({firstline;secondline},'Fontsize',12);
grid on; 

figure;
meshc(X,Y,cg);
% mesh(X,Y,cg);
% surf(X,Y,cg);
axis([cmin,cmax,gmin,gmax,30,100]);
xlabel('log2c','FontSize',12);
ylabel('log2g','FontSize',12);
zlabel('Accuracy(%)','FontSize',12);
firstline = 'SVC參數選擇結果圖(3D視圖)[GridSearchMethod]'; 
secondline = ['Best c=',num2str(bestc),' g=',num2str(bestg), ...
    ' CVAccuracy=',num2str(bestacc),'%'];
title({firstline;secondline},'Fontsize',12);

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM