python 拆分多類別數據集


原數據集形式:收集的數據來源於兩個文件夾(folder1 和 folder2),每個文件夾中的數據都分為三類(class1–class3),即 dataset/<folder>/<class>/<圖片>。

 

 希望得到的數據集形式:將每個文件夾的數據按比例拆分為 train 和 test 兩部分,且每部分都包含所有類別,即 dataset1/<folder>/train/<class>/ 與 dataset1/<folder>/test/<class>/。

 

完整代碼(已包含註釋,經自測可用;參考文獻:《數據集劃分、label生成及按label將圖片分類到不同文件夾》):

 1 import os  2 # import cv2
 3 import random  4 import sys  5 from random import randint  6 import shutil  7 
 8 def fileExist(path1):  9     if os.path.exists(path1): 10         return
11     else: 12         try: 13             os.mkdir(path1)  # 創建單層文件夾
14         except Exception as e: 15             os.makedirs(path1)  # 創建多層文件夾
16 
17 
def split_dataset(root_path, new_path, ratio=0.7, seed=None):
    """Randomly split every class folder under *root_path* into train/test copies.

    Input layout (one sub-directory per class)::

        root_path/<class>/<file>

    Output layout (destination directories are created as needed)::

        new_path/train/<class>/<file>
        new_path/test/<class>/<file>

    For each class, ``int(n * ratio)`` files (rounded down) are copied to
    ``train`` and the remaining files to ``test``.  Both ``train/<class>``
    and ``test/<class>`` are always created, so every class appears in both
    parts.  Files are copied, not moved; *root_path* is left untouched.

    Args:
        root_path: Directory whose sub-directories are the classes,
            e.g. ``dataset/folder1``.
        new_path: Destination directory, e.g. ``dataset1/folder1``.
        ratio: Fraction of each class's files that goes to ``train``
            (default 0.7); must be between 0 and 1.
        seed: Optional seed for a private ``random.Random`` so the split is
            reproducible.  ``None`` (default) uses the global ``random``
            module, as before.

    Raises:
        ValueError: if *ratio* is outside [0, 1].
    """
    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be between 0 and 1, got {ratio!r}")
    rng = random if seed is None else random.Random(seed)

    # os.listdir() order is filesystem-dependent; sorting makes a given seed
    # always produce the same split.
    for folder in sorted(os.listdir(root_path)):  # class1, class2, ...
        origin_path = os.path.join(root_path, folder)
        if not os.path.isdir(origin_path):
            continue  # skip stray files (e.g. .DS_Store) next to the classes
        img_list = sorted(
            name for name in os.listdir(origin_path)
            if os.path.isfile(os.path.join(origin_path, name))
        )

        train_num = int(len(img_list) * ratio)
        train_sample = rng.sample(img_list, train_num)
        train_set = set(train_sample)
        # Keep the remaining files in a stable order instead of set order.
        test_sample = [name for name in img_list if name not in train_set]

        for subset, items in (("train", train_sample), ("test", test_sample)):
            dst_dir = os.path.join(new_path, subset, folder)
            # Create the destination here so the function works on its own,
            # without relying on the caller to pre-create the folders.
            os.makedirs(dst_dir, exist_ok=True)
            for item in items:
                shutil.copy(src=os.path.join(origin_path, item),
                            dst=os.path.join(dst_dir, item))


if __name__ == '__main__':
    root_path = "dataset"
    new_path = "dataset1"

    # Only sub-directories of root_path are domains (folder1, folder2, ...);
    # stray files such as .DS_Store would otherwise crash os.listdir().
    domains = [d for d in sorted(os.listdir(root_path))
               if os.path.isdir(os.path.join(root_path, d))]

    # Create the output folders: dataset1/<domain>/{train,test}/<class>
    for domain in domains:
        domain_path = os.path.join(root_path, domain)
        domain_new_path = os.path.join(new_path, domain)
        for folder in os.listdir(domain_path):  # class1, class2, ...
            if not os.path.isdir(os.path.join(domain_path, folder)):
                continue  # not a class folder
            fileExist(os.path.join(domain_new_path, "train", folder))
            fileExist(os.path.join(domain_new_path, "test", folder))

    # Split each domain's data into the new location.
    for domain in domains:
        domain_path = os.path.join(root_path, domain)
        domain_new_path = os.path.join(new_path, domain)
        split_dataset(domain_path, domain_new_path)  # closing paren was missing

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM