This function determines cross-validated training and test scores for different training set sizes. A cross-validation generator splits the whole dataset k times into a training set and a test set. Subsets of the training set with varying sizes are used to fit the estimator, and a score is computed for each training subset size; a score on the test set is computed as well. Afterwards, for each training subset size, the scores from all k runs are averaged.
sklearn.model_selection.learning_curve(estimator, X, y, *, groups=None,
    train_sizes=array([0.1, 0.325, 0.55, 0.775, 1.]), cv=None, scoring=None,
    exploit_incremental_learning=False, n_jobs=None, pre_dispatch='all', verbose=0,
    shuffle=False, random_state=None, error_score=nan, return_times=False)
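As a minimal sketch of the behaviour described above (make_classification as a stand-in dataset, LogisticRegression, 3 folds and 3 training sizes are all illustrative choices of mine, not part of the original post), each requested training size produces one score per fold, and it is those per-fold scores that get averaged afterwards:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import learning_curve

# hypothetical stand-in dataset: 200 samples, 2 classes
X, y = make_classification(n_samples=200, n_features=5, random_state=0)

# 3 training sizes x 3 CV folds -> score matrices of shape (3, 3)
sizes, tr_scores, te_scores = learning_curve(
    LogisticRegression(max_iter=1000), X, y,
    train_sizes=[0.3, 0.6, 1.0], cv=3)

print(sizes)                   # absolute numbers of training samples actually used
print(tr_scores.shape)         # (n_ticks, n_cv_folds) == (3, 3)
print(te_scores.mean(axis=1))  # one averaged test score per training-set size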
Parameters:
(1) estimator: the base model (e.g. a decision tree or logistic regression)
(2) X: the feature matrix (without the label); if a DataFrame is not supported, use df.values
(3) y: the label (target values)
(4) groups: group labels for the samples, used when splitting the dataset into training/test sets
(5) train_sizes: array-like, shape (n_ticks,), dtype float or int: relative or absolute numbers of training examples used to generate the learning curve. If the dtype is float, the values are treated as fractions of the maximum training set size; otherwise they are interpreted as absolute sizes. Defaults to np.linspace(0.1, 1.0, 5) (see the short sketch after this list).
(6) cv: the cross-validation splitting strategy, 5-fold cross-validation by default. If the estimator is a classifier and y is binary or multiclass, StratifiedKFold is used; otherwise KFold is used.
The remaining parameters are not covered here.
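As a rough sketch of how (5) and (6) can be used (the absolute sizes and the explicit StratifiedKFold splitter below are my own illustrative choices, not defaults from the post), train_sizes may be given either as fractions or as absolute sample counts, and cv may be an integer or an explicit splitter:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import learning_curve, StratifiedKFold

x, y = load_iris(return_X_y=True)
est = LogisticRegression(max_iter=1000)

# float train_sizes: fractions of the largest possible training set
sizes_frac, _, _ = learning_curve(est, x, y, train_sizes=np.linspace(0.1, 1.0, 5), cv=5)

# int train_sizes: absolute numbers of training samples
sizes_abs, _, _ = learning_curve(est, x, y, train_sizes=[30, 60, 120], cv=5)

# cv as an explicit splitter instead of an integer
# (an int cv resolves to StratifiedKFold for classifiers, KFold otherwise)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
sizes_cv, _, _ = learning_curve(est, x, y, cv=skf)

print(sizes_frac)  # absolute sizes derived from the fractions
print(sizes_abs)   # [30, 60, 120] used directly as absolute sizes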
Returns:
train_sizes_abs: array, shape (n_unique_ticks,), dtype int: the numbers of training examples actually used to generate the learning curve. Since duplicate entries are removed, n_unique_ticks may be smaller than n_ticks.
train_scores: array, shape (n_ticks, n_cv_folds): scores on the training sets
test_scores: array, shape (n_ticks, n_cv_folds): scores on the test sets
Example using the iris dataset
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import learning_curve
from sklearn.linear_model import LogisticRegression  # model used for the predictions

iris = load_iris()
x = iris.data
y = iris.target

train_sizes, train_scores, test_scores = \
    learning_curve(estimator=LogisticRegression(random_state=1),
                   X=x, y=y,
                   train_sizes=np.linspace(0.5, 1.0, 5),  # 5 evenly spaced fractions between 0.5 and 1.0
                   cv=10, n_jobs=1)

# mean and spread of the per-fold scores for each training-set size
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)

plt.plot(train_sizes, train_mean, color='blue', marker='o',
         markersize=5, label='training accuracy')
plt.fill_between(train_sizes, train_mean + train_std, train_mean - train_std,
                 alpha=0.15, color='blue')
plt.plot(train_sizes, test_mean, color='green', linestyle='--', marker='s',
         markersize=5, label='validation accuracy')
plt.fill_between(train_sizes, test_mean + test_std, test_mean - test_std,
                 alpha=0.15, color='green')
plt.grid()
plt.xlabel('Number of training samples')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.ylim([0.6, 1.0])
plt.tight_layout()
plt.show()
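If you also want to see how expensive each training-set size is, the signature above includes return_times. As a small extension of the iris example (assuming scikit-learn 0.22 or newer, where this flag is available), passing return_times=True additionally returns the per-fold fit and score times:

# extension of the example above: also collect fit/score times
train_sizes, train_scores, test_scores, fit_times, score_times = \
    learning_curve(estimator=LogisticRegression(random_state=1),
                   X=x, y=y,
                   train_sizes=np.linspace(0.5, 1.0, 5),
                   cv=10, n_jobs=1,
                   return_times=True)

print(fit_times.shape)         # (n_ticks, n_cv_folds): seconds spent fitting per size/fold
print(fit_times.mean(axis=1))  # average fit time for each training-set size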