Scikit-Learn 源碼研讀 (第二期)基類的實現細節


sklearn項目可以看成一棵大樹,各種estimator是果實,而支撐這些估計器的主干,是為數不多的幾個基類。常見的幾個類有BaseEstimator、BaseSGD、ClassifierMixin、RegressorMixin,等等。

官方文檔的API參考頁面列出了主要的API接口,我們看下Base類

本期我們只研究BaseEstimator、ClassifierMixin、RegressorMixin、TransformerMixin。BaseSGD是一個比較大的話題,需要單獨開一期來仔細研究。

BaseEstimator

最底層的就是BaseEstimator類。主要暴露兩個方法:set_paramsget_params.

get_params

這個方法旨在獲取對象的參數,返回對象默認是{參數:參數值}的鍵值對。如果將get_params的參數deep設置為True,還會返回(如果有的話)子對象(它們是估計器)。下面我們來仔細看一下這個方法的實現細節:

為了節約篇幅,我會將不重要的注釋略去,以后都是這樣處理,不再贅述,除非特殊說明。

(1)
函數體中主要就是getattr方法,語法:getattr(對象,要檢索的屬性[,如果屬性不存在則返回的值])。Line200~208的任務是判斷self(一般就是估計器的實例)是否含有key這個參數,如果有就返回它的參數值,否則人為設置為None。

為什么要寫這么復雜呢? 其實可以直接寫作 value = getattr(self, key, None),有點迷~

(2)
再來看Line209~212,如果用戶設置了deep=True,並且value對象實現了get_params(說明value對象是一個子對象,即估計器,否則普通的參數是不會再次實現get_params方法的),則提取參數字典的鍵值對,並且寫入字典。整個函數最后返回的也是字典。

(3)
我們先快速的看一下這個方法具體是怎么使用的,然后再繼續追蹤源碼的實現。

from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(random_state=0)
X = [[ 1,  2,  3],  # 2 samples, 3 features
     [11, 12, 13]]
y = [0, 1]  # classes of each sample
clf.fit(X, y)

簡單的實例化一個隨機森林分類器的對象,我們看下對它調用get_params會返回什么:

clf.get_params()

{'bootstrap': True,
 'class_weight': None,
 'criterion': 'gini',
 'max_depth': None,
 'max_features': 'auto',
 'max_leaf_nodes': None,
 'min_impurity_decrease': 0.0,
 'min_impurity_split': None,
 'min_samples_leaf': 1,
 'min_samples_split': 2,
 'min_weight_fraction_leaf': 0.0,
 'n_estimators': 10,
 'n_jobs': None,
 'oob_score': False,
 'random_state': 0,
 'verbose': 0,
 'warm_start': False}

很明顯,這就是這個隨機森林分類器的默認參數方案。

(4)
我們注意到Line199這行,使用了另一個方法 for key in self._get_param_names():,現在研究該函數

這里贅述一下,在sklearn這種大型的Python項目中,很多暴露出去的方法,其實質只是一個殼子,你可以理解為它是在搬運別人做的東西,只是美化包裝一下交給調用者。例如get_params方法,它並沒有真的獲取到估計器實例的參數,因為_get_param_names在幫它干這個活兒。

@classmethod這個裝飾器直接告訴我們,該方法的適用對象是類自身,而非實例對象。

這個函數有很多檢查事項,真正獲取參數的是 inspect.signature(init).parameters.values(),最后獲取列表中每個對象的name屬性。

set_params

這個方法作用是設置參數。正常來說,我們在初始化估計器的時候定制化參數,但是也有臨時修改參數的需求,這時可以手工調用set_params方法。但是更多的還是由繼承BaseEstimator的類來調用這個方法。

具體地,我們看下實現細節:

這個方案支持處理嵌套字典,但是我們不去糾纏這么瑣碎,直接看到L251,setattr(self, key, value),對估計器的key屬性設置一個新的值。

應用的實例:

ClassifierMixin

Mixin表示混入類,可以簡單地理解為給其他的類增加一些額外的方法。Sklearn的分類、回歸混入類只實現了score方法,任何繼承它們的類需要自己去實現fitpredict等其他方法。

關於混入類,簡單的說就是一個父類,但是和普通的類有點不同,它需要指明元對象,_estimator_type。這里不再展開論述,感興趣的讀者請閱讀這篇討論 What is a mixin, and why are they useful?

可以看到,這個混入類的實現非常簡單,求預測值和真實值的准確率,返回值是一個浮點數。注意預測值來自self.predict(),所以繼承混入類的類必須自己實現predict方法,否則引發錯誤。后面不再重復強調該細節。

再次的,分類任務的混入類又是在搬運其它函數的勞動成果,那我們就來研究一下accuracy_score的實現細節

為簡潔起見,我們先忽略L185~189之間的代碼,后面會有專門研究分類任務的度量方法的文章,在那里我們再仔細研究它。直接看L191,y_ture == y_pred,這是一個簡單的寫法,精妙在於避免了for循環,快速的檢查兩個對象之間每一個元素是否相等並且返回True/False。L193對score結果做一層包裝。

  • L116:如果設置了normalize參數為True,則對score列表取平均值,就是預測正確的樣本個數/總體個數=預測准確率
  • L118:如果有權重,則按照權重對各個樣本的得分進行加權,作為最終的預測准確率
  • L121:如果沒有上述兩種設置,則直接返回預測正確的樣本的個數。注意:sklearn默認的score方法返回預測准確率,而非預測正確的樣本個數。

RegressorMixin

毫不意外地,回歸任務的混入類只實現了score方法,核心數學原理是 \(R^2\) 值。公式是 1-((y_true - y_pred)2)/((y_true - y_true_mean)2),直觀上看,這個值是衡量預測值與真實值的偏離度與真實值自身偏離度的一個比值。 \(R^2\)最大為1,表示預測完全准確,值為0時表示模型沒有任何預測能力。

score方法調用了metrics模塊的r2_score方法,返回值是浮點數。我們來研究下r2_score,這個函數是目前為止我們看過的最復雜的一個。因此,我們一塊一塊來研究。

檢查傳入的對象

(1)檢查傳入對象的長度
L577調用check_consistent_length檢查輸入標簽、輸出標簽、權重是不是有相同的長度。檢查的方法也很簡單,對每個對象計算長度,然后取不同的長度值有多少個,如果超過1個,說明幾個對象之間的長度不一,則引發一個錯誤來警告。

(2)檢查傳入的參數是否合法
L575調用_check_reg_targets方法,旨在檢查傳入參數是否合法。

這個函數略長,但是大致做了以下幾件事:

  • L83~95都是在做檢查和格式轉換。
  • L97~114檢查輸入multioutputy_true是否吻合,即真實的標簽數組的維度如果是1的話,顯然設置multioutput這個參數非None是不合法的。並且當真實標簽數組的維度大於1的時候,若其維度和multioutput不同時也會引發錯誤以告警。
  • L115根據y_true的維度決定標簽是哪種類型,分為:連續型和多類輸出的連續型。
    注意:multioutput可以是字符串,也可以是一個數組,還可以是None值(考慮到向下兼容),因此這個參數非常靈活。后面研究具體算法時遇到了會再次提及,此處不作過多糾纏。

檢查樣本數和權重系數

繼續看r2_score的實現:

(3)L597~582檢查預測值的樣本數
如果預測值的樣本數不足2個,則引發錯誤告警。因為決定系數(即\(R^2\))要求至少要有2個樣本

(4)L584~588處理權重系數

  • L585調用np.ravel(),把權重數組拉平到一維
  • L586對sample_weights擴維,將一維擴充為二維,二維擴充為三維,以此類推。值得注意的是,np.newaxis放置的位置不同,擴充的方向是不同的,具體看下面這個小例子:
  • L588,如果沒有傳入權重系數,則默認設置為1

實現\(R^2\)的計算細節

(5)構造分子和分母

(6)計算每個樣本的得分

  • L595~596 記錄分母和分子的數組中不為0的索引值(就是非0值所在的位置)
  • L597 記錄分子、分母同時不為0的樣本的索引值。如果對這個寫法不熟悉,這里有個小例子幫助理解:
  • L598~599 創建一個和真實標簽相同長度的全1數組,然后對合法的索引位置計算真實的\(R^2\)值。
  • L603 將分母為0的索引位置的值設置為0,這里設為其他常數也是可以的,對於同一個回歸任務的評價沒有影響。

(7)根據multioutput參數來決定各樣本所得分數的權重

  • L605~607 如果指明raw_values,則輸出每個樣本的分數
  • L608~610 如果指明uniform_average,則avg_weights設置為None,其實就是均勻分布權重
  • L611~612 如果指明variance_weighted,則直接用分母作權重
  • L614~618 處理常量y值或一維數組的情形。如果分母全是0,則:若分子有非0,直接返回1;否則返回0
  • L620 如果multioutput不是字符串,則直接把它作為最后的權重系數

(8)返回得分

return np.average(output_scores, weights=avg_weights)

剛剛說到,指明uniform_average,則avg_weights設置為None。在numpy.average這個方法里,如果權重是None,計算均值就是簡單的mean()函數。

TransformerMixin

這個混入類的實現比較簡單,完全依靠使用它的類自己實現的fit方法和transform方法。但是它會根據是否有標簽,決定是有監督任務還是無監督任務。等后面遇到再具體討論。

補充

我們在研究分類混入類和回歸混入類的時候,都發現有_estimator_type這個變量,它的具體作用就是這里看到的,判斷一個估計器是用於分類任務還是回歸任務的。


如果有任何紕漏差錯,歡迎評論互動。

drawing


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM