轉自 :http://blog.csdn.net/aliceyangxi1987/article/details/73598857
學習曲線是什么?
學習曲線就是通過畫出不同訓練集大小時訓練集和交叉驗證的准確率,可以看到模型在新數據上的表現,進而來判斷模型是否方差偏高或偏差過高,以及增大訓練集是否可以減小過擬合。
怎么解讀?
當訓練集和測試集的誤差收斂但卻很高時,為高偏差。
左上角的偏差很高,訓練集和驗證集的准確率都很低,很可能是欠擬合。
我們可以增加模型參數,比如,構建更多的特征,減小正則項。
此時通過增加數據量是不起作用的。
當訓練集和測試集的誤差之間有大的差距時,為高方差。
當訓練集的准確率比其他獨立數據集上的測試結果的准確率要高時,一般都是過擬合。
右上角方差很高,訓練集和驗證集的准確率相差太多,應該是過擬合。
我們可以增大訓練集,降低模型復雜度,增大正則項,或者通過特征選擇減少特征數。
理想情況是是找到偏差和方差都很小的情況,即收斂且誤差較小。
怎么畫?
在畫學習曲線時,橫軸為訓練樣本的數量,縱軸為准確率。
例如同樣的問題,左圖為我們用 naive Bayes 分類器時,效果不太好,分數大約收斂在 0.85,此時增加數據對效果沒有幫助。
右圖為 SVM(RBF kernel),訓練集的准確率很高,驗證集的也隨着數據量增加而增加,不過因為訓練集的還是高於驗證集的,有點過擬合,所以還是需要增加數據量,這時增加數據會對效果有幫助。
上圖的代碼如下:
模型這里用 GaussianNB 和 SVC 做比較,
模型選擇方法中需要用到 learning_curve 和交叉驗證方法 ShuffleSplit。
import numpy as np import matplotlib.pyplot as plt from sklearn.naive_bayes import GaussianNB from sklearn.svm import SVC from sklearn.datasets import load_digits from sklearn.model_selection import learning_curve from sklearn.model_selection import ShuffleSplit
首先定義畫出學習曲線的方法,
核心就是調用了 sklearn.model_selection 的 learning_curve,
學習曲線返回的是 train_sizes, train_scores, test_scores,
畫訓練集的曲線時,橫軸為 train_sizes, 縱軸為 train_scores_mean,
畫測試集的曲線時,橫軸為 train_sizes, 縱軸為 test_scores_mean:
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None, n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)): ~~~ train_sizes, train_scores, test_scores = learning_curve( estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes) train_scores_mean = np.mean(train_scores, axis=1) test_scores_mean = np.mean(test_scores, axis=1) ~~~
在調用 plot_learning_curve 時,首先定義交叉驗證 cv 和學習模型 estimator。
這里交叉驗證用的是 ShuffleSplit, 它首先將樣例打散,並隨機取 20% 的數據作為測試集,這樣取出 100 次,最后返回的是 train_index, test_index,就知道哪些數據是 train,哪些數據是 test。
estimator 用的是 GaussianNB,對應左圖:
cv = ShuffleSplit(n_splits=100, test_size=0.2, random_state=0) estimator = GaussianNB() plot_learning_curve(estimator, title, X, y, ylim=(0.7, 1.01), cv=cv, n_jobs=4)
再看 estimator 是 SVC 的時候,對應右圖:
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0) estimator = SVC(gamma=0.001) plot_learning_curve(estimator, title, X, y, (0.7, 1.01), cv=cv, n_jobs=4)