Python超參數自動搜索模塊GridSearchCV上手


1. 引言

當我們跑機器學習程序時,尤其是調節網絡參數時,通常待調節的參數有很多,參數之間的組合更是繁復。依照注意力>時間>金錢的原則,人力手動調節注意力成本太高,非常不值得。For循環或類似於for循環的方法受限於太過分明的層次,不夠簡潔與靈活,注意力成本高,易出錯。本文介紹sklearn模塊的GridSearchCV模塊,能夠在指定的范圍內自動搜索具有不同超參數的不同模型組合,有效解放注意力。

2. GridSearchCV模塊簡介

這個模塊是sklearn模塊的子模塊,導入方法非常簡單

from sklearn.model_selection import GridSearchCV

函數原型:

class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True)

其中cv可以是整數或者交叉驗證生成器或一個可迭代器,cv參數對應的4種輸入列舉如下:

  1. None:默認參數,函數會使用默認的3折交叉驗證
  2. 整數k:k折交叉驗證。對於分類任務,使用StratifiedKFold(類別平衡,每類的訓練集占比一樣多,具體可以查看官方文檔)。對於其他任務,使用KFold
  3. 交叉驗證生成器:得自己寫生成器,頭疼,略
  4. 可以生成訓練集與測試集的迭代器:同上,略

3. 分析結果自動保存

逗號分隔值(Comma-Separated Values,CSV,有時也稱為字符分隔值,因為分隔字符也可以不是逗號),其文件以純文本形式存儲表格數據(數字和文本)。純文本意味着該文件是一個,不含必須像二進制數字那樣被解讀的數據。CSV文件由任意數目的記錄組成,記錄間以某種換行符分隔;每條記錄由字段組成,字段間的分隔符是其它字符或字符串,最常見的是逗號或制表符。通常,所有記錄都有完全相同的字段序列。

CSV文件有個突出的優點,可以用excel等軟件打開,比起記事本和matlab、python等編程語言界面,便於查看、制作報告、后期整理等。

GridSearchCV模塊中,不同超參數的組合方式及其計算結果以字典的形式保存在 clf.cv_results_中,python的pandas模塊提供了高效整理數據的方法,只需要3行代碼即可解決問題。

cv_result = pd.DataFrame.from_dict(clf.cv_results_)
with open('cv_result.csv','w') as f:
  cv_result.to_csv(f)

4. 完整例程

代碼清晰易懂,無須解釋。https://github.com/JiJingYu/tensorflow-exercise/tree/master/svm_grid_search

 1 import pandas as pd
 2 from sklearn import svm, datasets
 3 from sklearn.model_selection import GridSearchCV
 4 from sklearn.metrics import classification_report
 5 
 6 iris = datasets.load_iris()
 7 parameters = {'kernel':('linear', 'rbf'), 'C':[1, 2, 4], 'gamma':[0.125, 0.25, 0.5 ,1, 2, 4]}
 8 svr = svm.SVC()
 9 clf = GridSearchCV(svr, parameters, n_jobs=-1)
10 clf.fit(iris.data, iris.target)
11 cv_result = pd.DataFrame.from_dict(clf.cv_results_)
12 with open('cv_result.csv','w') as f:
13     cv_result.to_csv(f)
14     
15 print('The parameters of the best model are: ')
16 print(clf.best_params_)
17 
18 y_pred = clf.predict(iris.data)
19 print(classification_report(y_true=iris.target, y_pred=y_pred))

5. 相關資料

  1. sklearn.model_selection.GridSearchCV模塊主頁: http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
  2. pandas.DataFrame模塊主頁:http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html
  3. 本文例程 https://github.com/JiJingYu/tensorflow-exercise/tree/master/svm_grid_search

6.未來展望

   當前的工作局限於算法超參數搜索,還沒有結合預處理方式自動搜索、不同算法之間自動搜索、不同深度學習模型自動搜索等。如何利用pipeline、keras、tf等模塊,實現整個環節的自動搜索,是下一步學習與總結的方向。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM