1. 引言
當我們跑機器學習程序時,尤其是調節網絡參數時,通常待調節的參數有很多,參數之間的組合更是繁復。依照注意力>時間>金錢的原則,人力手動調節注意力成本太高,非常不值得。For循環或類似於for循環的方法受限於太過分明的層次,不夠簡潔與靈活,注意力成本高,易出錯。本文介紹sklearn模塊的GridSearchCV模塊,能夠在指定的范圍內自動搜索具有不同超參數的不同模型組合,有效解放注意力。
2. GridSearchCV模塊簡介
這個模塊是sklearn模塊的子模塊,導入方法非常簡單
from sklearn.model_selection import GridSearchCV
函數原型:
class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True)
其中cv可以是整數或者交叉驗證生成器或一個可迭代器,cv參數對應的4種輸入列舉如下:
- None:默認參數,函數會使用默認的3折交叉驗證
- 整數k:k折交叉驗證。對於分類任務,使用StratifiedKFold(類別平衡,每類的訓練集占比一樣多,具體可以查看官方文檔)。對於其他任務,使用KFold
- 交叉驗證生成器:得自己寫生成器,頭疼,略
- 可以生成訓練集與測試集的迭代器:同上,略
3. 分析結果自動保存
逗號分隔值(Comma-Separated Values,CSV,有時也稱為字符分隔值,因為分隔字符也可以不是逗號),其文件以純文本形式存儲表格數據(數字和文本)。純文本意味着該文件是一個,不含必須像二進制數字那樣被解讀的數據。CSV文件由任意數目的記錄組成,記錄間以某種換行符分隔;每條記錄由字段組成,字段間的分隔符是其它字符或字符串,最常見的是逗號或制表符。通常,所有記錄都有完全相同的字段序列。
CSV文件有個突出的優點,可以用excel等軟件打開,比起記事本和matlab、python等編程語言界面,便於查看、制作報告、后期整理等。
GridSearchCV模塊中,不同超參數的組合方式及其計算結果以字典的形式保存在 clf.cv_results_中,python的pandas模塊提供了高效整理數據的方法,只需要3行代碼即可解決問題。
cv_result = pd.DataFrame.from_dict(clf.cv_results_) with open('cv_result.csv','w') as f: cv_result.to_csv(f)
4. 完整例程
代碼清晰易懂,無須解釋。https://github.com/JiJingYu/tensorflow-exercise/tree/master/svm_grid_search
1 import pandas as pd 2 from sklearn import svm, datasets 3 from sklearn.model_selection import GridSearchCV 4 from sklearn.metrics import classification_report 5 6 iris = datasets.load_iris() 7 parameters = {'kernel':('linear', 'rbf'), 'C':[1, 2, 4], 'gamma':[0.125, 0.25, 0.5 ,1, 2, 4]} 8 svr = svm.SVC() 9 clf = GridSearchCV(svr, parameters, n_jobs=-1) 10 clf.fit(iris.data, iris.target) 11 cv_result = pd.DataFrame.from_dict(clf.cv_results_) 12 with open('cv_result.csv','w') as f: 13 cv_result.to_csv(f) 14 15 print('The parameters of the best model are: ') 16 print(clf.best_params_) 17 18 y_pred = clf.predict(iris.data) 19 print(classification_report(y_true=iris.target, y_pred=y_pred))
5. 相關資料
- sklearn.model_selection.GridSearchCV模塊主頁: http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
- pandas.DataFrame模塊主頁:http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html
- 本文例程 https://github.com/JiJingYu/tensorflow-exercise/tree/master/svm_grid_search
6.未來展望
當前的工作局限於算法超參數搜索,還沒有結合預處理方式自動搜索、不同算法之間自動搜索、不同深度學習模型自動搜索等。如何利用pipeline、keras、tf等模塊,實現整個環節的自動搜索,是下一步學習與總結的方向。