sklearn.model_selection.RandomizedSearchCV隨機搜索超參數


GridSearchCV可以保證在指定的參數范圍內找到精度最高的參數,但是這也是網格搜索的缺陷所在,它要求遍歷所有可能參數的組合,在面對大數據集和多參數的情況下,非常耗時。這也是我通常不會使用GridSearchCV的原因,一般會采用后一種RandomizedSearchCV隨機參數搜索的方法

RandomizedSearchCV的使用方法其實是和GridSearchCV一致的,但它以隨機在參數空間中采樣的方式代替了GridSearchCV對於參數的網格搜索,在對於有連續變量的參數時,RandomizedSearchCV會將其當作一個分布進行采樣這是網格搜索做不到的,它的搜索能力取決於設定的n_iter參數

函數用法:

class sklearn.model_selection.RandomizedSearchCV(estimator, param_distributions, *, n_iter=10, 
scoring=None, n_jobs=None, iid='deprecated', refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
random_state=None, error_score=nan, return_train_score=False)

參數詳解:

estimator:估計器

param_distributions 字典或字典列表:參數字典,key是參數名,values是參數范圍

n_iter int,默認= 10:抽取樣本是訓練次數

更多參數參考:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html

RandomSearchCV是如何"隨機搜索"的

考察其源代碼,其搜索策略如下:
(a)對於搜索范圍是distribution的超參數,根據給定的distribution隨機采樣;
(b)對於搜索范圍是list的超參數,在給定的list中等概率采樣;
(c)對a、b兩步中得到的n_iter組采樣結果,進行遍歷。
(補充)如果給定的搜索范圍均為list,則不放回抽樣n_iter次。

import numpy as np
from scipy.stats import randint as sp_randint
from sklearn.model_selection import RandomizedSearchCV
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier

# 載入數據
digits = load_digits()
X, y = digits.data, digits.target

# 建立一個分類器或者回歸器
clf = RandomForestClassifier(n_estimators=20)

# 給定參數搜索范圍:list or distribution
param_dist = {"max_depth": [3, None],                     #給定list
              "max_features": sp_randint(1, 11),          #給定distribution
              "min_samples_split": sp_randint(2, 11),     #給定distribution
              "bootstrap": [True, False],                 #給定list
              "criterion": ["gini", "entropy"]}           #給定list

# 用RandomSearch+CV選取超參數
n_iter_search = 20
random_search = RandomizedSearchCV(clf, param_distributions=param_dist,
                                   n_iter=n_iter_search, cv=5, iid=False)
clf=random_search.fit(X, y)
clf.best_params_ 
{'bootstrap': False,
 'criterion': 'entropy',
 'max_depth': None,
 'max_features': 9,
 'min_samples_split': 8}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM