數據集的格式如下:
datasets----train文件夾(WA和WKY文件夾,里面分別存放了200張圖片)
----test文件夾(WA和WKY文件夾,里面分別存放了100張圖片)
每一張圖片都有自己的文件名;標籤由所在的文件夾決定:WA的圖片標籤為0,WKY的圖片標籤為1(train和test相同)。
1.構建Dataset
import os
import random
import torch
from torch.utils.data import Dataset
import torchvision
import imghdr
from PIL import Image
import matplotlib.pyplot as plt


class MedicalDataset(Dataset):
    """Two-class image dataset read from ``<root>/<split>/<class>/``.

    The class folders are ``WA`` (label 0) and ``WKY`` (label 1); every entry
    returned by ``os.listdir`` for a class folder is treated as one image.

    Args:
        root: Dataset root directory.
        split: Sub-folder of ``root`` to read. ``'train'`` selects the
            augmenting pipeline (random crop + random horizontal flip); any
            other value selects the deterministic centre-crop pipeline.
        data_ratio: Fraction of the samples to keep. Values below 1.0 keep a
            random subset of ``round(data_ratio * N)`` samples.

    Attributes:
        img_list: Paths of the samples that ``__getitem__`` serves.
        cls_list: Integer label for each entry of ``img_list``.
        cls_num: Number of files found per class, counted *before* the
            ``data_ratio`` subsampling.
        trans: Transform pipeline applied to every loaded image.
    """

    # Channel statistics handed to Normalize (the usual ImageNet values).
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, root, split, data_ratio=1.0):
        super().__init__()
        self.img_list = list()  # path of every sample
        self.cls_list = list()  # label index of every sample
        self.cls_num = dict()   # number of files per class

        classes = ['WA', 'WKY']
        for idx, cls in enumerate(classes):
            cls_dir = os.path.join(root, split, cls)
            file_names = sorted(os.listdir(cls_dir))
            self.cls_num[cls] = len(file_names)
            for file_name in file_names:
                self.img_list.append(os.path.join(cls_dir, file_name))
                self.cls_list.append(idx)

        if data_ratio < 1.0:
            # Keep a random subset; paths and labels are picked with the same
            # indices so they stay aligned.
            shuffled_idxs = list(range(len(self.img_list)))
            random.shuffle(shuffled_idxs)
            num_samples = round(data_ratio * len(self.img_list))
            kept = shuffled_idxs[:num_samples]
            self.img_list = [self.img_list[i] for i in kept]
            self.cls_list = [self.cls_list[i] for i in kept]

        transforms = torchvision.transforms
        if split == 'train':
            spatial = [transforms.RandomCrop(224),
                       transforms.RandomHorizontalFlip()]
        else:
            spatial = [transforms.CenterCrop(224)]
        self.trans = transforms.Compose(
            [transforms.Resize(256)]
            + spatial
            + [transforms.ToTensor(),
               transforms.Normalize(self._NORM_MEAN, self._NORM_STD)]
        )

    def _getdata(self):
        """Return the list of sample paths."""
        return self.img_list

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image_tensor, label)``."""
        name = self.img_list[index]
        # Force three channels: Normalize above is configured for exactly
        # three, so grayscale / RGBA / palette files would otherwise fail.
        # The context manager releases the file handle once pixels are copied.
        with Image.open(name) as handle:
            img = handle.convert('RGB')
        img = self.trans(img)
        label = self.cls_list[index]
        # Both values must be returned so the DataLoader collates
        # (images, labels) batches.
        return img, label

    def __len__(self):
        """Return the number of samples served by this dataset."""
        return len(self.img_list)
如果想查看圖片的話,可以用下面的代碼取出一個樣本並顯示出來:
from torch.utils.data import DataLoader

dataset = MedicalDataset('datasets/', 'train')
print('dataset: ', dataset)
print('len= ', len(dataset))  # total number of training samples: 400

img, label = dataset[-1]
print('img.shape= ', img.shape)  # torch.Size([3, 224, 224])
print('label= ', label)  # 1

# Each element yielded by the loader is a batch of what __getitem__ returns.
loader = DataLoader(dataset, batch_size=16, shuffle=True)
# Fetch ONE batch and unpack it; calling next(iter(loader)) twice would build
# two iterators and load two different shuffled batches.
batch_imgs, batch_labels = next(iter(loader))
print(batch_imgs.shape, batch_labels.shape)  # torch.Size([16, 3, 224, 224]), torch.Size([16])

# Show one image
unloader = torchvision.transforms.ToPILImage()  # converts a tensor/ndarray to a PIL image


def imshow(tensor, title=None, denormalize=True):
    """Display an image tensor with matplotlib.

    Args:
        tensor: Image tensor of shape (3, H, W) or (1, 3, H, W).
        title: Optional figure title.
        denormalize: When True (default), undo the Normalize step used by
            MedicalDataset before displaying; pass False for tensors that
            are already in the [0, 1] range.
    """
    image = tensor.cpu().clone()  # clone so the caller's tensor is not modified
    image = image.squeeze(0)      # drop a leading batch dimension of size 1
    if denormalize:
        # Same statistics as the Normalize transform in MedicalDataset.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        image = (image * std + mean).clamp(0, 1)

    image = unloader(image)  # tensor -> PIL image
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(1)  # short pause so the figure gets drawn


plt.figure()
imshow(img, title='Image')
2.創建DataLoader
# Command-line option for the dataset location; argparse exposes it as
# ``args.dataset_path``.
parser.add_argument(
    "--dataset-path",
    type=str,
    default='./datasets',
    help="Path of the trainset.",
)
from torch.utils.data import DataLoader, Subset, random_split

# Build the datasets
train_dataset = MedicalDataset(args.dataset_path, 'train')
test_dataset = MedicalDataset(args.dataset_path, 'test')
print(len(train_dataset), len(test_dataset))  # 400 training samples, 200 test samples

# Split the training data into training and validation parts, ratio 8:2
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_split = random_split(train_dataset, [train_size, val_size])

# random_split only produces index views over ONE dataset object, so the
# validation part would inherit the random crop / flip of the training
# pipeline. Read the validation indices through a second view of the same
# files that uses the deterministic (test-time) transform instead. Both views
# list the files in the same sorted order, so the indices refer to the same
# images.
val_source = MedicalDataset(args.dataset_path, 'train')
val_source.trans = test_dataset.trans
val_dataset = Subset(val_source, val_split.indices)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                          num_workers=args.num_workers, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.num_workers, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                         num_workers=args.num_workers, pin_memory=True)
3.oversample過采樣
假如train中WA和WKY的數據不平衡(例如:訓練集中WA有1555張,WKY有496張;驗證集中WA有223張,WKY有70張;測試集中WA有444張,WKY有142張),就需要對WKY的訓練集和驗證集進行過采樣。這裡的過采樣不是單純的重複:每張WKY圖片的路徑在列表中重複三次,但每次讀取時都會隨機選擇一種數據增強,所以得到的樣本並不相同。測試集不用管。
import os
import random
import torch
from torch.utils.data import Dataset
import torchvision
from PIL import Image


class MedicalDataset(Dataset):
    """WA / WKY image dataset that oversamples the WKY class.

    Images are read from ``<root>/<split>/<class>/``. ``WA`` gets label 0 and
    ``WKY`` label 1. For the ``train`` and ``val`` splits every WKY path is
    listed ``oversample_factor`` times; because a random augmentation is
    chosen each time a sample is loaded, the repeated entries do not yield
    identical tensors. The ``test`` split is neither oversampled nor
    augmented.

    Args:
        root: Dataset root directory.
        split: One of ``'train'``, ``'val'``, ``'test'``.
        data_ratio: Fraction of the (oversampled) samples to keep. Values
            below 1.0 keep a random subset of ``round(data_ratio * N)``.
        ret_name: When True, ``__getitem__`` also returns the file path.
        oversample_factor: How many times each WKY path is listed for the
            ``train`` / ``val`` splits. Must be at least 1.

    Attributes:
        img_list / cls_list: Sample paths and their integer labels.
        cls_num: Samples per class after oversampling, counted before the
            ``data_ratio`` subsampling.
        cls_to_ind_dict / ind_to_cls_dict: Class-name <-> index lookups.
        trans0 .. trans6, trans_list: The augmenting pipelines.
        eval_trans: Deterministic pipeline used for the ``test`` split.
    """

    # Channel statistics handed to Normalize (the usual ImageNet values).
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]
    # Class whose samples are repeated in the train / val splits.
    _MINORITY_CLASS = 'WKY'

    def __init__(self, root, split, data_ratio=1.0, ret_name=False,
                 oversample_factor=3):
        assert split in ['train', 'val', 'test']
        if oversample_factor < 1:
            raise ValueError('oversample_factor must be >= 1')
        super().__init__()
        self.split = split
        self.ret_name = ret_name
        self.cls_to_ind_dict = dict()
        self.ind_to_cls_dict = list()
        self.img_list = list()
        self.cls_list = list()
        self.cls_num = dict()

        oversample = split != 'test'
        classes = ['WA', 'WKY']
        for idx, cls in enumerate(classes):
            self.cls_to_ind_dict[cls] = idx
            self.ind_to_cls_dict.append(cls)

            cls_dir = os.path.join(root, split, cls)
            paths = [os.path.join(cls_dir, file_name)
                     for file_name in sorted(os.listdir(cls_dir))]
            if oversample and cls == self._MINORITY_CLASS:
                # List every minority-class path several times in a row.
                paths = [p for p in paths for _ in range(oversample_factor)]

            self.cls_num[cls] = len(paths)  # per-class count after repetition
            self.img_list.extend(paths)
            self.cls_list.extend([idx] * len(paths))

            if oversample:
                print(cls, '=======================')
                print(len(paths), len(paths))

        if oversample:
            print(len(self.img_list), len(self.cls_list))

        if data_ratio < 1.0:
            # Keep a random subset; paths and labels are picked with the same
            # indices so they stay aligned.
            shuffled_idxs = list(range(len(self.img_list)))
            random.shuffle(shuffled_idxs)
            kept = shuffled_idxs[:round(data_ratio * len(self.img_list))]
            self.img_list = [self.img_list[i] for i in kept]
            self.cls_list = [self.cls_list[i] for i in kept]

        transforms = torchvision.transforms
        # Always flip horizontally.
        self.trans0 = self._augmenting(transforms.RandomHorizontalFlip(p=1))
        # Always flip vertically.
        self.trans1 = self._augmenting(transforms.RandomVerticalFlip(p=1))
        # Rotate by a random angle in [-90, 90] degrees.
        self.trans2 = self._augmenting(transforms.RandomRotation(90))
        # Brightness factor drawn from [0, 2]; a factor of 1 leaves the image
        # unchanged.
        self.trans3 = self._augmenting(transforms.ColorJitter(brightness=1))
        # Contrast factor drawn from [0, 3]; a factor of 1 leaves the image
        # unchanged.
        self.trans4 = self._augmenting(transforms.ColorJitter(contrast=2))
        # Hue shift drawn from [-0.5, 0.5].
        self.trans5 = self._augmenting(transforms.ColorJitter(hue=0.5))
        # Brightness, contrast and hue jitter combined.
        self.trans6 = self._augmenting(
            transforms.ColorJitter(brightness=1, contrast=2, hue=0.5))
        self.trans_list = [self.trans0, self.trans1, self.trans2, self.trans3,
                           self.trans4, self.trans5, self.trans6]

        # Deterministic pipeline so that test results are reproducible and
        # not distorted by flips, rotations or colour jitter.
        self.eval_trans = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
        ])

    @classmethod
    def _augmenting(cls, augmentation):
        """Wrap one augmentation in the shared resize/crop/normalize steps."""
        transforms = torchvision.transforms
        return transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            augmentation,
            transforms.ToTensor(),
            transforms.Normalize(cls._NORM_MEAN, cls._NORM_STD),
        ])

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns ``(image, label)``, or ``(image, label, path)`` when the
        dataset was created with ``ret_name=True``.
        """
        name = self.img_list[index]
        # Force three channels: Normalize is configured for exactly three.
        # The context manager releases the file handle once pixels are copied.
        with Image.open(name) as handle:
            img = handle.convert('RGB')
        if self.split == 'test':
            img = self.eval_trans(img)
        else:
            # A different random augmentation on every access is what makes
            # the repeated minority-class entries differ from one another.
            img = random.choice(self.trans_list)(img)
        label = self.cls_list[index]
        if self.ret_name:
            return img, label, name
        return img, label

    def __len__(self):
        """Return the number of samples served by this dataset."""
        return len(self.img_list)
擴展後WKY的訓練集個數為1488(496×3),驗證集個數為210(70×3),測試集個數依然是142。通過過采樣,無論WA做正例還是負例,得到的靈敏度都相似,不會有非常大的差別。