機器學習筆記:sklearn交叉驗證之KFold與StratifiedKFold


一、交叉驗證

機器學習中常用交叉驗證函數:KFoldStratifiedKFold

方法導入:

from sklearn.model_selection import KFold, StratifiedKFold
  • StratifiedKFold:采用分層划分的方法(分層隨機抽樣思想),驗證集中不同類別占比與原始樣本的比例一致,划分時需傳入標簽特征
  • KFold:默認隨機划分訓練集、驗證集

二、KFold交叉驗證

1.使用語法

sklearn.model_selection.KFold(n_splits=3, # 最少2折
                             shuffle=False, # 是否打亂
                             random_state=None)

2.實操

  • get_n_splits -- 返回折數
  • split -- 切分
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

X = np.array([[1,2], [3,4], [5,6], [7,8]])
y = np.array([1,2,3,4])
kf = KFold(n_splits=2)
kf.get_n_splits() # 2
print(kf) # KFold(n_splits=2, random_state=None, shuffle=False)

# 此處的split只需傳入數據,不需要傳入標簽
for train_index, test_index in kf.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
'''
TRAIN: [2 3] TEST: [0 1]
TRAIN: [0 1] TEST: [2 3]
'''

三、StratifiedKFold交叉驗證

1.使用語法

sklearn.model_selection.StratifiedKFold(n_splits=3, # 同KFold參數
                                       shuffle=False,
                                       random_state=None)

2.實操

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

X = np.array([[1,2], [3,4], [5,6], [7,8]])
y = np.array([1,0,0,1])
skf = StratifiedKFold(n_splits=2)
skf.get_n_splits() # 2
print(skf) # StratifiedKFold(n_splits=2, random_state=None, shuffle=False)

# 同時傳入數據集和標簽
for train_index, test_index in skf.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

注意:拆分的折數必須大於等於標簽類別,否則報錯:

ValueError: n_splits=2 cannot be greater than the number of members in each class.

參考鏈接:sklearn.model_selection.KFold

參考鏈接:sklearn.model_selection.StratifiedKFold

參考鏈接:python sklearn中KFold與StratifiedKFold


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM