在機器學習或者深度學習中,我們常常碰到一個問題是數據集的切分。比如在一個比賽中,舉辦方給我們的只是一個帶標注的訓練集和不帶標注的測試集。其中訓練集是用於訓練,而測試集用於已訓練模型上跑出一個結果,然后提交,然后舉辦方驗證結果給出一個分數。但是我們在訓練過程中,可能會出現過擬合等問題,會面臨着算法和模型的選擇,此時,驗證集就顯得很重要。通常,如果數據量充足,我們會從訓練集中划分出一定比例的數據來作為驗證集。
每次划分數據集都手動寫一個腳本,重復性太高,因此將此簡單的腳本放到自己的博客。代碼如下:
1 import random 2 3 def split(full_list,shuffle=False,ratio=0.2): 4 n_total = len(full_list) 5 offset = int(n_total * ratio) 6 if n_total==0 or offset<1: 7 return [],full_list 8 if shuffle: 9 random.shuffle(full_list) 10 sublist_1 = full_list[:offset] 11 sublist_2 = full_list[offset:] 12 return sublist_1,sublist_2 13 14 15 if __name__ == "__main__": 16 li = range(5) 17 sublist_1,sublist_2 = split(li,shuffle=True,ratio=0.2) 18 19 print sublist_1,len(sublist_1) 20 print sublist_2,len(sublist_2)
其中,main為測試代碼。假如訓練集給出的是一個文件,我們先將文件讀到列表中,然后再調用split。