Decision_function:scores,predict以及其他


機器學習的評估

PR曲線用於positive類數據占比比較小,或者你更加在意false postion(相比於false negative);其他情況采用ROC曲線;比如Demo中手寫體5的判斷,因為只有少量5,所以從ROC上面來看分類效果不錯,但是從PR曲線可以看到分類器效果不佳。

 

y_scores = sgd_clf.decision_function([some_digit])

decision_function代表的是參數實例到各個類所代表的超平面的距離;在梯度下滑里面特有的(隨機森林里面沒有decision_function),這個返回的距離,或者說是分值;后續的對於這個值的利用方式是指定閾值來進行過濾:

>>> y_scores = sgd_clf.decision_function([some_digit])

>>> y_scores

array([ 161855.74572176])

>>> threshold = 0

>>> y_some_digit_pred = (y_scores > threshold)

array([ True], dtype=bool)

 

>>> threshold = 200000

>>> y_some_digit_pred = (y_scores > threshold)

>>> y_some_digit_pred

array([False], dtype=bool)

通過上面例子看到了,通過decision_function可以獲得一種"分值",這個分值的幾何意義就是當前點到超平面(hyperplane)的距離;然后,你可以利用這個分值來和某個閾值做比較(距離的閾值),超過閾值則通過,低於閾值則不通過。再舉一個例子:

>>> sgd_clf.fit(X_train, y_train) # y_train, not y_train_5

>>> sgd_clf.predict([some_digit])

array([ 5.])

some_digit_scores=sdg_clf.decision_function([some_digit])

some_digit_scores

array([[-311402.62954431, -363517.28355739, -446449.5306454 ,

-183226.61023518, -414337.15339485, 161855.74572176,

-452576.39616343, -471957.14962573, -518542.33997148,

-536774.63961222]])

 

sgd_clf.fit(X_train, y_train)這個梯度下降算法學習的對象是說有手寫訓練樣本以及0-9的分類標簽,基於學習的模型調用decision_function之后,獲取是[some_digit]所有的標簽到超平面的距離,其中只有5是正值,所以如果調用predict的話返回的就是5。但是,如果我們訓練的分類器是二元分類器(True,false),那么情況又不同:

y_train_5 =(y_train==5)

>>> sgd_clf.fit(X_train,y_train_5) # y_train, not y_train_5

>>> sgd_clf.predict([some_digit])

array([ True])

因為y_train_5這個標簽集合只有True和False兩種標簽,所以訓練之后的模型預測的知識True和false;所以到底是二元分類還是多元分類完全取決於訓練的時候的標簽集。

 

predict:用於分類模型的預測分類;

fit:對於線性回歸的模型學習,稱之為"擬合";

y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method="predict_proba")

cross_val_predict是交叉獲取分類概率(注意,這里的method參數設置為"predict_proba",代表返回值返回的是預期分類的概率)

參考:

http://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html

 

?這有一個問題其實每太搞懂,就是scores和predict的關系到底什么,cross_val_score的機制和cross_val_predit之間的差別是什么,文中代碼如下:

from sklearn.ensemble import RandomForestClassifier

forest_clf = RandomForestClassifier(random_state=42)

y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,

method="predict_proba")

 

But to plot a ROC curve, you need scores, not probabilities. A simple solution is to

use the positive class's probability as the score:

 

y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class

fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)

 

不過這里的代碼可以看出一些端倪:

from sklearn.ensemble import RandomForestClassifier

forest_clf=RandomForestClassifier(random_state=42)

y_probas_forest=cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method="predict_proba")

print y_probas_forest

[[0.9 0.1] [1. 0. ] [1. 0. ] ... [1. 0. ] [0.9 0.1] [1. 0. ]]

y_scores_forest=y_probas_forest[:, 1]

print y_scores_forest

[0.1 0. 0. ... 0. 0.1 0. ]

你可以看到,scores是probas的二維數組的第二維的值。那么問題來了,作為cross_val_predict里面的數據,二維數組中這個二維到底是什么?這個二維數組其實代表的是各個分類的概率,對於二分類而言,就是為negative的概率以及position概率;對於scores其實就是為position的分類信息。那就意味着如果N個分類(classification),那么就是N維數組了。

另外對於森林分類器里面有一個method的參數,例子中值是"predict_proba",這個代表着預測各個分類的概率;他還有很多其他選項:

predict:代表的是預測的分類,就是會挑選概率最大的分類返回;

predict_log_proba:算法和predic_proba是一樣的,但是最后對於結果會取對數運算,目的是放大值,避免在概率的相乘中會產生一些極小值,然后會因為舍入問題導致誤差;另外一些機器算法(比如散度KL)本身就是基於對數運算的。最后,貝葉斯的分類算法需要通過對數運算(log)來實現穩定性;

對於cv=3,代表采用三折交叉驗證,即將數據隨機分為三份(或者盡量保持數據的均勻分布性),每次拿其中的一份來做測試集(另外兩份做訓練集),然后將三次的結果(每個測試樣本各個分類的概率)做一下平均值;

 

參考

https://stackoverflow.com/questions/20335944/why-use-log-probability-estimates-in-gaussiannb-scikit-learn

https://www.reddit.com/r/MLQuestions/comments/5lzv9o/sklearn_why_predict_log_proba/

https://baike.baidu.com/item/%E5%AF%B9%E6%95%B0%E5%85%AC%E5%BC%8F

https://stats.stackexchange.com/questions/329857/what-is-the-difference-between-decision-function-predict-proba-and-predict-fun

https://stackoverflow.com/questions/36543137/whats-the-difference-between-predict-proba-and-decision-function-in-scikit-lear


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM