一、簡介
交叉驗證(Cross validation,簡稱CV)是在機器學習建立模型和驗證模型參數時常用的辦法,一般被用於評估一個機器學習模型的表現。交叉驗證的基本思想是把在某種意義下將原始數據(dataset)進行分組,一部分做為訓練集(train set),另一部分做為驗證集(validation set or test set),首先用訓練集對分類器進行訓練,再利用驗證集來測試訓練得到的模型(model),以此來做為評價分類器的性能指標。常見CV的方法有Holdout 驗證、K-fold cross-validation、留一驗證。
1. Holdout 驗證
方法:將原始數據隨機分為兩組,一組作為訓練集,一組作為驗證集。利用訓練集訓練分類器,利用驗證集驗證模型,記錄最后的分類准確率為此Hold-Out Method下分類器的性能指標。
優缺點:此方法的好處是處理簡單,只需要隨機把原始數據分為兩組即可。但是,Holdout 驗證嚴格意義上並不能算是CV,因為這種方法並沒有“交叉”的思想。由於是隨機地將原始數據分組,所以最后驗證集分類准確率的高低與原始數據的分組有很大的關系,所以這種方法得到的結果其實並不具有說服性。
2. K-fold cross-validation
方法:將原始數據分割成K組數據集,每個單獨的數據集作為驗證集,其余的K-1個數據集用來訓練,交叉驗證重復K次,共得到K個模型,用這K個模型最終的驗證集的分類准確率的平均數作為此K-CV下分類器的性能指標。K一般大於等於2,實際操作時一般從3開始取,一般而言K=10是最常用的。
優缺點:K-CV作為方法1的演進,可以有效地避免過學習以及欠學習狀態的發生,最后得到的結果也比較具有說服性。其主要缺點在於K值的選取上。
3. Leave-One_Out Cross Validation(LOO-CV)
方法:如果設原始數據有N個樣本,那么LOO-CV就是N-CV,即每個樣本單獨作為驗證集,其余的N-1個樣本作為訓練集,所以LOO-CV會得到N個模 型,用這N個模型最終的驗證集的分類准確率的平均數作為此下LOO-CV分類器的性能指標。
優點:相比於前面的K-CV,LOO-CV有兩個明顯的優點:1)每一回合中幾乎所有的樣本皆用於訓練模型,因此最接近原始樣本的分布,這樣評估所得的結果比較可靠。2)實驗過程中沒有隨機因素會影響實驗數據,確保實驗過程是可以被復制的。但LOO-CV的缺點則是計算成本高,因為需要建立的模型數量與原始數據樣本數量相同,當原始數據樣本數量相當多時,LOO-CV在實作上便有困難幾乎就是不顯示,除非每次訓練分類器得到模型的速度很快,或是可以用並行化計算減少計算所需的時間。
二、 MATLAB實踐(K-CV)
在使用svm時,需要采用交叉驗證選擇最佳參數c和g。libsvm中的svmtrain函數內置交叉驗證選項,svmtrain的options如下:
-s svm類型:SVM模型設置類型(默認值為0) 0:C - SVC 1:nu - SVC 2:one - class SVM 3: epsilon - SVR 4: nu - SVR - t 核函數類型:核函數設置類型(默認值為2) 0:線性核函數 u'v 1:多項式核函數(r *u'v + coef0)^degree 2:RBF 核函數 exp( -r|u - v|^2) 3:sigmiod核函數 tanh(r * u'v + coef0) - d degree:核函數中的 degree 參數設置(針對多項式核函數,默認值為3) - g r(gama):核函數中的gama參數設置(針對多項式/sigmoid 核函數/RBF/,默認值為屬性數目的倒數) - r coef0:核函數中的coef0參數設置(針對多項式/sigmoid核函數,默認值為0) - c cost:設置 C - SVC,epsilon - SVR 和 nu - SVR的參數(默認值為1) - n nu:設置 nu-SVC ,one - class SVM 和 nu - SVR的參數 - p epsilon:設置 epsilon - SVR 中損失函數的值(默認值為0.1) - m cachesize:設置 cache 內存大小,以 MB 為單位(默認值為100) - e eps:設置允許的終止閾值(默認值為0.001) - h shrinking:是否使用啟發式,0或1(默認值為1) - wi weight:設置第幾類的參數 C 為 weight * C(對於 C - SVC 中的 C,默認值為1) - v n:n - fold 交互檢驗模式,n為折數,必須大於等於2
其中,-v 隨機地將數據分為n部分,並計算交互檢驗准確度和均方根誤差。大致實現代碼如下:

% Cross_Validation % K-fold cross validation % Author Ethan % Date 2020/4/10 % Version 1.0 fprintf('Beginning crossvalidation\n') crossval_start = tic; %best_accuracy = 0; best_cv = 0; best_c = 0; best_g = 0; k = 3; % number of folds for log2c = -5:5 for log2g = -5:5 cmd = ['-t 0','-v ',num2str(k),'-c ',num2str(2^log2c),'-g ',num2str(2^log2g)]; cv = svmtrain(labels,train_matrix,cmd); if cv >= best_cv best_cv = cv; best_c = 2^log2c; best_g = 2^log2g; end end end crossval_elapsed = toc(crossval_start); fprintf('SVM crosvalidation done in: %f seconds.\n',crossval_elapsed); fprintf('Best crossval reached: %d, with cost=%d\n\n', best_cv, best_c); %svm_params = ['-t ',num2str(0) ,' -c ', num2str(best_c),' -b 1'];