交叉驗證 Cross-validation (MATLAB)


一、簡介  

  交叉驗證(Cross validation,簡稱CV)是在機器學習建立模型和驗證模型參數時常用的辦法,一般被用於評估一個機器學習模型的表現。交叉驗證的基本思想是把在某種意義下將原始數據(dataset)進行分組,一部分做為訓練集(train set),另一部分做為驗證集(validation set or test set),首先用訓練集對分類器進行訓練,再利用驗證集來測試訓練得到的模型(model),以此來做為評價分類器的性能指標。常見CV的方法有Holdout 驗證、K-fold cross-validation、留一驗證。

1. Holdout 驗證

  方法:將原始數據隨機分為兩組,一組作為訓練集,一組作為驗證集。利用訓練集訓練分類器,利用驗證集驗證模型,記錄最后的分類准確率為此Hold-Out Method下分類器的性能指標。

  優缺點:此方法的好處是處理簡單,只需要隨機把原始數據分為兩組即可。但是,Holdout 驗證嚴格意義上並不能算是CV,因為這種方法並沒有“交叉”的思想。由於是隨機地將原始數據分組,所以最后驗證集分類准確率的高低與原始數據的分組有很大的關系,所以這種方法得到的結果其實並不具有說服性。

2. K-fold cross-validation

  方法:將原始數據分割成K組數據集,每個單獨的數據集作為驗證集,其余的K-1個數據集用來訓練,交叉驗證重復K次,共得到K個模型,用這K個模型最終的驗證集的分類准確率的平均數作為此K-CV下分類器的性能指標。K一般大於等於2,實際操作時一般從3開始取,一般而言K=10是最常用的。

  優缺點:K-CV作為方法1的演進,可以有效地避免過學習以及欠學習狀態的發生,最后得到的結果也比較具有說服性。其主要缺點在於K值的選取上。

3. Leave-One_Out Cross Validation(LOO-CV)

  方法:如果設原始數據有N個樣本,那么LOO-CV就是N-CV,即每個樣本單獨作為驗證集,其余的N-1個樣本作為訓練集,所以LOO-CV會得到N個模 型,用這N個模型最終的驗證集的分類准確率的平均數作為此下LOO-CV分類器的性能指標。

  優點:相比於前面的K-CV,LOO-CV有兩個明顯的優點:1)每一回合中幾乎所有的樣本皆用於訓練模型,因此最接近原始樣本的分布,這樣評估所得的結果比較可靠。2)實驗過程中沒有隨機因素會影響實驗數據,確保實驗過程是可以被復制的。但LOO-CV的缺點則是計算成本高,因為需要建立的模型數量與原始數據樣本數量相同,當原始數據樣本數量相當多時,LOO-CV在實作上便有困難幾乎就是不顯示,除非每次訓練分類器得到模型的速度很快,或是可以用並行化計算減少計算所需的時間。

  

二、 MATLAB實踐(K-CV)

   在使用svm時,需要采用交叉驗證選擇最佳參數c和g。libsvm中的svmtrain函數內置交叉驗證選項,svmtrain的options如下:

-s svm類型:SVM模型設置類型(默認值為0)
    0:C - SVC
    1:nu - SVC
    2:one - class SVM
    3: epsilon - SVR
    4: nu - SVR
- t 核函數類型:核函數設置類型(默認值為2)
    0:線性核函數 u'v
    1:多項式核函數(r *u'v + coef0)^degree
    2:RBF 核函數 exp( -r|u - v|^2)
    3:sigmiod核函數 tanh(r * u'v + coef0)
- d degree:核函數中的 degree 參數設置(針對多項式核函數,默認值為3)
- g r(gama):核函數中的gama參數設置(針對多項式/sigmoid 核函數/RBF/,默認值為屬性數目的倒數)
- r coef0:核函數中的coef0參數設置(針對多項式/sigmoid核函數,默認值為0)
- c cost:設置 C - SVC,epsilon - SVR 和 nu - SVR的參數(默認值為1)
- n nu:設置 nu-SVC ,one - class SVM 和 nu - SVR的參數
- p epsilon:設置 epsilon - SVR 中損失函數的值(默認值為0.1- m cachesize:設置 cache 內存大小,以 MB 為單位(默認值為100)
- e eps:設置允許的終止閾值(默認值為0.001- h shrinking:是否使用啟發式,0或1(默認值為1)
- wi weight:設置第幾類的參數 C 為 weight * C(對於 C - SVC 中的 C,默認值為1)
- v n:n - fold 交互檢驗模式,n為折數,必須大於等於2

 

  其中,-v 隨機地將數據分為n部分,並計算交互檢驗准確度和均方根誤差。大致實現代碼如下:

% Cross_Validation
% K-fold cross validation
% Author Ethan
% Date 2020/4/10
% Version 1.0

fprintf('Beginning crossvalidation\n')
crossval_start = tic;

%best_accuracy = 0;
best_cv = 0;
best_c = 0;
best_g = 0;
k = 3;  % number of folds

for log2c = -5:5
    for log2g = -5:5
        cmd = ['-t 0','-v ',num2str(k),'-c ',num2str(2^log2c),'-g ',num2str(2^log2g)];
        cv = svmtrain(labels,train_matrix,cmd);

        if cv >= best_cv
            best_cv = cv;
            best_c = 2^log2c;
            best_g = 2^log2g;
        end
    end
end

crossval_elapsed = toc(crossval_start);
fprintf('SVM crosvalidation done in: %f seconds.\n',crossval_elapsed);
fprintf('Best crossval reached: %d, with cost=%d\n\n', best_cv, best_c);

%svm_params = ['-t ',num2str(0) ,' -c ', num2str(best_c),' -b 1'];
View Code

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM