sklearn中cross_val_score、cross_val_predict的用法比較


https://blog.csdn.net/cxx654/article/details/104813830

 

sklearn中cross_val_score、cross_val_predict的用法比較

程大海 2020-03-12 11:02:36 8444 收藏 21
分類專欄: python編程 機器學習 文章標簽: 機器學習 sklearn cross_val_score 交叉驗證
版權

python編程
同時被 2 個專欄收錄
49 篇文章0 訂閱
訂閱專欄

機器學習
33 篇文章0 訂閱
訂閱專欄
交叉驗證的概念,直接粘貼scikit-learn官網的定義:

 

 

scikit-learn中計算交叉驗證的函數:

cross_val_score:得到K折驗證中每一折的得分,K個得分取平均值就是模型的平均性能

cross_val_predict:得到經過K折交叉驗證計算得到的每個訓練驗證的輸出預測

方法:

cross_val_score:分別在K-1折上訓練模型,在余下的1折上驗證模型,並保存余下1折中的預測得分

cross_val_predict:分別在K-1上訓練模型,在余下的1折上驗證模型,並將余下1折中樣本的預測輸出作為最終輸出結果的一部分

 

結論:

cross_val_score計算得到的平均性能可以作為模型的泛化性能參考

cross_val_predict計算得到的樣本預測輸出不能作為模型的泛化性能參考

from sklearn import datasets
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
import numpy as np
from sklearn.tree import DecisionTreeClassifier
 
# 加載鳶尾花數據集
iris = datasets.load_iris()
iris_train = iris.data
iris_target = iris.target
print(iris_train.shape)
print(iris_target.shape)
(150, 4)
(150,)
 
# 構建決策樹分類模型
tree_clf = DecisionTreeClassifier()
tree_clf.fit(iris_train, iris_target)
tree_predict = tree_clf.predict(iris_train)
​
# 計算決策樹分類模型的准確率
from sklearn.metrics import accuracy_score
print("Accuracy:", accuracy_score(iris_target, tree_predict))
Accuracy: 1.0
 
# 交叉驗證cross_val_score輸出每一折上的准確率
from sklearn.model_selection import cross_val_predict, cross_val_score, cross_validate
tree_scores = cross_val_score(tree_clf, iris_train, iris_target, cv=3)
print(tree_scores)
[0.98039216 0.92156863 1.        ]
 
# 交叉驗證cross_val_predict輸出每個樣本的預測結果
tree_predict = cross_val_predict(tree_clf, iris_train, iris_target, cv=3)
print(tree_predict)
print(len(tree_predict))
print(accuracy_score(iris_target, tree_predict))
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1
 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2
 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 1 2 2 2 2 2 2 2 2 2
 2 2]
150
0.96
 
print(tree_clf.predict(iris_train))
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
 
# 交叉驗證cross_validate對cross_val_score結果進行包裝,並包含fit的時間等信息
tree_val = cross_validate(tree_clf, iris_train, iris_target, cv=3)
print(tree_val)
{'fit_time': array([0., 0., 0.]), 'score_time': array([0., 0., 0.]), 'test_score': array([0.98039216, 0.92156863, 0.97916667])}
 
​
 
​

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM