GridSearchCV


GridSearchCV 簡介:

  GridSearchCV,它存在的意義就是自動調參,只要把參數輸進去,就能給出最優化的結果和參數。但是這個方法適合於小數據集,一旦數據的量級上去了,很難得出結果。這個時候就是需要動腦筋了。數據量比較大的時候可以使用一個快速調優的方法——坐標下降。它其實是一種貪心算法:拿當前對模型影響最大的參數調優,直到最優化;再拿下一個影響最大的參數調優,如此下去,直到所有的參數調整完畢。這個方法的缺點就是可能會調到局部最優而不是全局最優,但是省時間省力,巨大的優勢面前,還是試一試吧,后續可以再拿bagging再優化。回到sklearn里面的GridSearchCV,GridSearchCV用於系統地遍歷多種參數組合,通過交叉驗證確定最佳效果參數。

常用參數解讀:

estimator:所使用的分類器,如estimator=RandomForestClassifier(min_samples_split=100,min_samples_leaf=20,max_depth=8,max_features='sqrt',random_state=10), 並且傳入除需要確定最佳的參數之外的其他參數。每一個分類器都需要一個scoring參數,或者score方法。

param_grid:值為字典或者列表,即需要最優化的參數的取值,param_grid =param_test1,param_test1 = {'n_estimators':range(10,71,10)}。

scoring :准確度評價標准,默認None,這時需要使用score函數;或者如scoring='roc_auc',根據所選模型不同,評價准則不同。字符串(函數名),或是可調用對象,需要其函數簽名形如:scorer(estimator, X, y);如果是None,則使用estimator的誤差估計函數。scoring參數選擇如下:

參考地址:http://scikit-learn.org/stable/modules/model_evaluation.html

cv :交叉驗證參數,默認None,使用三折交叉驗證。指定fold數量,默認為3,也可以是yield訓練/測試數據的生成器。
refit :默認為True,程序將會以交叉驗證訓練集得到的最佳參數,重新對所有可用的訓練集與開發集進行,作為最終用於性能評估的最佳模型參數。即在搜索參數結束后,用最佳參數結果再次fit一遍全部數據集。
iid:默認True,為True時,默認為各個樣本fold概率分布一致,誤差估計為所有樣本之和,而非各個fold的平均。
verbose:日志冗長度,int:冗長度,0:不輸出訓練過程,1:偶爾輸出,>1:對每個子模型都輸出。
n_jobs: 並行數,int:個數,-1:跟CPU核數一致, 1:默認值。
pre_dispatch:指定總共分發的並行任務數。當n_jobs大於1時,數據將在每個運行點進行復制,這可能導致OOM,而設置pre_dispatch參數,則可以預先划分總共的job數量,使數據最多被復制pre_dispatch次。

 

常用方法:

grid.fit():運行網格搜索
grid_scores_:給出不同參數情況下的評價結果
best_params_:描述了已取得最佳結果的參數的組合
best_score_:成員提供優化過程期間觀察到的最好的評分

 

使用案例:

param_test1 = { 'max_depth':list(range(3,10,1))
               ,'min_child_weight':list(range(1,6,2))
              }
gsearch1 = GridSearchCV(
    estimator = XGBClassifier(
                                silent=1 ,#設置成1則沒有運行信息輸出,最好是設置為0.是否在運行升級時打印消息。
                                learning_rate= 0.1,# 如同學習率
#                                 max_depth=4,# 構建樹的深度,越大越容易過擬合
                                colsample_bytree=1, # 生成樹時進行的列采樣
                                reg_lambda=4,  # 控制模型復雜度的權重值的L2正則化項參數,參數越大,模型越不容易過擬合。
                                objective= 'binary:logistic', #多分類的問題 指定學習任務和相應的學習目標
                                n_estimators=70, # 樹的個數
                                ),
                       param_grid = param_test1, scoring='roc_auc', iid=False, cv=5)
gsearch1.fit(X,y)
gsearch1.grid_scores_, gsearch1.best_params_, gsearch1.best_score_

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM