交叉驗證和網格搜索


一、交叉驗證(Cross Validation)

1. 目的

交叉驗證的目的是為了讓模型評估更加准確可信。

2. 基本思想

基本思想是將原始數據(dataset)進行分組,一部分做為訓練集(train set),另一部分做為驗證集(validation set or test set),首先用訓練集對分類器進行訓練,再利用驗證集來測試訓練得到的模型,以此來作為評價分類器的性能指標。

3. 主要方法

交叉驗證主要有以下三種方法:

  • Holdout驗證
  • K折交叉驗證
  • 留一驗證

3.1 Holdout驗證

將原始數據隨機分為兩組,一組做為訓練集,一組做為驗證集,利用訓練集訓練分類器,然后利用驗證集驗證模型。

3.2 K折交叉驗證(K-fold Cross Validation)

以10折交叉驗證為例,如下圖所示。

步驟如下:

  1. 將數據集平均分成不相交的10個子集
  2. 每一次挑選其中的1份作為測試集,其余的9份作為訓練集進行模型訓練,得到模型的指標
  3. 重復第2步10次,使每個子集都作為1次測試集,得到10個模型的指標
  4. 將10個模型指標取平均值,作為10折交叉驗證的模型的指標

3.3 留一驗證(Leave-One-Out Cross Validation,LOOCV)

留一驗證是K折交叉驗證的特例,假設原始數據有N個樣本,每個樣本單獨作為驗證集,其余的N-1個樣本作為訓練集。此方法主要用於樣本量非常少的情況。

二、網格搜索(Grid Search)

通常情況下,很多超參數需要調節,但是手動過程繁雜,所以需要對模型預設幾種超參數組合,每組超參數都采用交叉驗證來進行評估。最后選出最優參數組合建立模型。

sklearn中網格搜索API

	sklearn.model_selection.GridSearchCV(estimator,param_grid,cv)

estimator:估計器對象
param_grid:估計器參數,參數名稱(字符串)作為key,要測試的參數列表作為value的字典,或這樣的字典構成的列表
cv:整形,指定K折交叉驗證
方法:
fit:輸入訓練數據
score:准確率
best_score_:交叉驗證中測試的最好的結果
best_estimator_:交叉驗證中測試的最好的參數模型
best_params_:交叉驗證中測試的最好的參數
cv_results_:每次交叉驗證的結果

簡單示例如下:

knn = KNeighborsClassifier()

param = {"n_neighbors": [3,5,10]}
gscv = GridSearchCV(knn, param_grid=param, cv=10)

gscv.fit(x_train, y_train)

print(gscv.score(x_test, y_test))
print(gscv.best_score_)
print(gscv.best_estimator_)
print(gscv.best_params_)
print(pd.DataFrame(gscv.cv_results_).T)

到不了的地方都叫做遠方,回不去的世界都叫做家鄉,我一直向往的卻是比遠更遠的地方。——《幽靈公主》


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM