https://blog.csdn.net/cxx654/article/details/104813830
Comparing cross_val_score and cross_val_predict in sklearn
程大海, 2020-03-12 11:02:36
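For the concept of cross-validation, here is the definition taken directly from the official scikit-learn documentation:
scikit-learn functions for computing cross-validation:
cross_val_score: returns the score obtained on each of the K folds; averaging the K scores gives the model's average performance
cross_val_predict: returns, for every sample, the prediction made during K-fold cross-validation by the model that did not see that sample during training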
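Both functions loop over the same K train/validation splits. For a classifier with a binary or multiclass target, an integer `cv` (such as `cv=3`) means scikit-learn uses `StratifiedKFold` without shuffling. The sketch below assumes a recent scikit-learn and inspects those splits directly. It checks that every sample lands in exactly one validation fold, and it shows that each fold keeps roughly the class proportions of the full dataset.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold

X, y = load_iris(return_X_y=True)

# For a classifier, an integer cv=3 corresponds to StratifiedKFold(n_splits=3)
skf = StratifiedKFold(n_splits=3)

seen = np.zeros(len(y), dtype=int)  # how many times each sample is used for validation
for k, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    seen[test_idx] += 1
    # Train on the K-1 folds in train_idx, validate on the held-out fold in test_idx
    print(f"fold {k}: train={len(train_idx)}, validate={len(test_idx)}, "
          f"class counts in validation fold={np.bincount(y[test_idx])}")

# Every sample is validated exactly once across the K folds
print("each sample validated once:", bool(np.all(seen == 1)))
```

Exact fold sizes depend on the scikit-learn version, because the way `StratifiedKFold` distributes samples has changed over time. This is one reason the scores printed later in this post may not reproduce exactly.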
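How they work:
cross_val_score: for each of the K folds in turn, train the model on the other K-1 folds, evaluate it on the held-out fold, and record the score on that fold
cross_val_predict: for each of the K folds in turn, train the model on the other K-1 folds, predict the held-out fold, and store those predictions as that fold's part of the final output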
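The procedure above can be reproduced by hand. This sketch makes assumptions that are not in the original: a `StratifiedKFold(n_splits=3)` splitter, and `random_state=0` on the tree so that every fit is deterministic. One loop collects the per-fold accuracy, which is what `cross_val_score` keeps. It also collects the out-of-fold predictions, which is what `cross_val_predict` keeps. The manual results are then compared against both scikit-learn functions.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
# random_state is an assumption added here so the manual loop and scikit-learn
# fit identical trees and their results can be compared exactly
clf = DecisionTreeClassifier(random_state=0)
cv = StratifiedKFold(n_splits=3)

fold_scores = []
oof_pred = np.empty_like(y)  # out-of-fold predictions, one per sample
for train_idx, test_idx in cv.split(X, y):
    model = clone(clf).fit(X[train_idx], y[train_idx])   # train on K-1 folds
    pred = model.predict(X[test_idx])                      # predict the held-out fold
    fold_scores.append(accuracy_score(y[test_idx], pred))  # what cross_val_score keeps
    oof_pred[test_idx] = pred                              # what cross_val_predict keeps

print("manual scores:  ", np.round(fold_scores, 4))
print("cross_val_score:", np.round(cross_val_score(clf, X, y, cv=cv), 4))
print("predictions match cross_val_predict:",
      bool(np.array_equal(oof_pred, cross_val_predict(clf, X, y, cv=cv))))
```

No `scoring` argument is passed, so `cross_val_score` falls back to the estimator's `score` method. For a classifier, that is accuracy, which is why the manual loop uses `accuracy_score`.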
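Conclusion:
The average of the cross_val_score results can be used as an estimate of the model's generalization performance
The predictions returned by cross_val_predict should not be used directly to measure generalization performance: they are pooled from K different models, and the scikit-learn documentation notes that a metric computed on them can differ from the cross_val_score results unless all validation folds have the same size and the metric decomposes over samples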
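Here is one concrete way the two numbers diverge. The mean of the `cross_val_score` results is an unweighted average over folds. Accuracy computed on the pooled `cross_val_predict` output weights each fold by its size instead. The sketch assumes a shuffled `StratifiedKFold` with 4 splits and a fixed `random_state`; both choices are for illustration only. With 4 folds, 150 samples cannot be split evenly, so the folds differ in size. With identical splits and a deterministic tree, the pooled accuracy equals the size-weighted mean of the fold scores, not the plain mean.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0)
# 150 samples cannot be split into 4 equal folds, so fold sizes differ;
# an integer random_state makes every call to split() return the same folds
cv = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)

scores = cross_val_score(clf, X, y, cv=cv)
fold_sizes = np.array([len(test) for _, test in cv.split(X, y)])
pooled = accuracy_score(y, cross_val_predict(clf, X, y, cv=cv))

print("fold sizes:                    ", fold_sizes)
print("mean of fold scores:           ", scores.mean())
print("size-weighted mean of scores:  ", np.average(scores, weights=fold_sizes))
print("accuracy of pooled predictions:", pooled)
```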
```python
from sklearn import datasets
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Load the iris dataset
iris = datasets.load_iris()
iris_train = iris.data
iris_target = iris.target
print(iris_train.shape)
print(iris_target.shape)
```
Output:
```
(150, 4)
(150,)
```
```python
# Build a decision-tree classifier. No random_state is set, so tie-breaking between
# features (and therefore the exact cross-validation scores below) can vary between runs
tree_clf = DecisionTreeClassifier()
tree_clf.fit(iris_train, iris_target)
tree_predict = tree_clf.predict(iris_train)

# Accuracy of the decision tree on its own training data (overstates generalization)
from sklearn.metrics import accuracy_score
print("Accuracy:", accuracy_score(iris_target, tree_predict))
```
Output:
```
Accuracy: 1.0
```
```python
# cross_val_score: accuracy on each of the 3 folds
from sklearn.model_selection import cross_val_predict, cross_val_score, cross_validate
tree_scores = cross_val_score(tree_clf, iris_train, iris_target, cv=3)
print(tree_scores)
```
Output:
```
[0.98039216 0.92156863 1. ]
```
```python
# cross_val_predict: the out-of-fold prediction for every sample
tree_predict = cross_val_predict(tree_clf, iris_train, iris_target, cv=3)
print(tree_predict)
print(len(tree_predict))
print(accuracy_score(iris_target, tree_predict))
```
Output:
```
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2]
150
0.96
```
```python
# For comparison: predictions of the tree fitted on the full training set
# (cross_val_predict fits clones, so tree_clf is still the model fitted above)
print(tree_clf.predict(iris_train))
```
Output:
```
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
```
```python
# cross_validate returns the per-fold test scores in a dict,
# together with extra information such as fit and scoring times
tree_val = cross_validate(tree_clf, iris_train, iris_target, cv=3)
print(tree_val)
```
Output:
```
{'fit_time': array([0., 0., 0.]), 'score_time': array([0., 0., 0.]), 'test_score': array([0.98039216, 0.92156863, 0.97916667])}
```
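The third-fold score differs between the two runs above (1.0 from `cross_val_score`, 0.97916667 from `cross_validate`). This is consistent with `DecisionTreeClassifier` randomly permuting features at each split when `random_state` is not fixed. The sketch below assumes `random_state=0`, a value chosen only for illustration. With it fixed, repeated cross-validation runs become reproducible, and `cross_validate`'s `test_score` matches `cross_val_score`.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
# random_state fixes the tree's internal feature permutation, so every fit is deterministic
tree_clf = DecisionTreeClassifier(random_state=0)

scores = cross_val_score(tree_clf, X, y, cv=3)
results = cross_validate(tree_clf, X, y, cv=3)

print(scores)
print(results["test_score"])
print(sorted(results))  # keys include fit_time, score_time and test_score
```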