1、Dataset+DataLoader實現自定義數據集讀取方法
創建自己的數據集需要繼承父類torch.utils.data.Dataset,同時需要重載兩個私有成員函數:def __len__(self)和def __getitem__(self, index) 。 def __len__(self)應該返回數據集的大小;def __getitem__(self, index)接收一個index,然后返回圖片數據和標簽,這個index通常指的是一個list的index,這個list的每個元素就包含了圖片數據的路徑和標簽信息。如何制作這個list呢,通常的方法是將圖片的路徑和標簽信息存儲在一個txt中,然后從該txt中讀取。
整個流程如下:
1.1整體框架
1 class MyDataset(torch.utils.data.Dataset):#需要繼承torch.utils.data.Dataset 2 def __init__(self): 3 #對繼承自父類的屬性進行初始化(好像沒有這句也可以??) 4 super(MyDataset,self).__init__() 5 # TODO 6 #1、初始化一些參數和函數,方便在__getitem__函數中調用。 7 #2、制作__getitem__函數所要用到的圖片和對應標簽的list。 8 #也就是在這個模塊里,我們所做的工作就是初始化該類的一些基本參數。 9 pass 10 def __getitem__(self, index): 11 # TODO 12 #1、根據list從文件中讀取一個數據(例如,使用numpy.fromfile,PIL.Image.open)。 13 #2、預處理數據(例如torchvision.Transform)。 14 #3、返回數據對(例如圖像和標簽)。 15 #這里需要注意的是,這步所處理的是index所對應的一個樣本。 16 pass 17 def __len__(self): 18 #返回數據集大小 19 return len()
1.2例子講解
下面用VOC2012數據集的處理作為例子,我們使用UNet網絡處理數據集。
1.2.1VOC數據集的介紹
這是下載的一個voc2012的數據集,上邊的是測試集,下邊的是訓練+驗證集。
VOCdevkit_train結構如下:
該數據集不全是用來做圖像分割的,也有目標檢測等其他任務的。
Annotations是保存每一張圖片的xml信息,其中包含是否用來分割等選項,目標檢測Xmin,Xmax,Ymin,Ymax等信息,此處我們用不到
JPEGimages就是需要用到的圖片
SegmentationClass是語義分割的label,需要復制整個文件夾
SegmentationObject是實例分割的,暫時用不到
ImageSets這里需要用到,里邊有各種任務的txt文件--我們可以根據此來找圖像分割需要用到的圖片,下圖為ImageSets里的文件:
我們這里只用到下圖的Segmentation部分,打開可以看到以下幾個文件,我們可以根據這幾個txt文件把圖像分割需要的數據提取出去。
1.2.2具體代碼
1 import os 2 import os.path as osp 3 import logging 4 import pandas as pd 5 import numpy as np 6 from PIL import Image 7 import torch 8 from torch.utils.data import Dataset 9 from PIL import Image 10 from cv2 import cv2 as cv 11 12 class BasicDataset(Dataset): 13 def __init__(self, imgs_dir, labels_dir, text_dir): 14 self.imgs_dir = imgs_dir 15 self.labels_dir = labels_dir 16 f = open(text_dir, 'r') 17 self.img_names = f.read().splitlines() 18 logging.info(f'Creating dataset with {len(self.img_names)} examples') 19 20 def __len__(self): 21 return len(self.img_names) 22 23 def preprocess(self, img):#對圖像進行處理的函數 24 img = np.array(img) 25 if len(img.shape) == 2:#如果img數組的維度是二維,就把它變成三維的 26 # label target image 27 img = np.expand_dims(img, axis=2)#在最后增加一維 28 29 img = img / 255.0 #把矩陣中的像素點值變成0-1范圍中的數,這么做的目的是讓目標與背景的差異變得明顯 30 # HWC to CHW 31 img_trans = img.transpose((2, 0, 1)) 32 #return img_trans 33 34 def __getitem__(self, i):#這里的i就是從文件夾中提取的第i張圖片,當作一個索引 35 img_name = self.img_names[i] #第i個圖片的名字 36 img_path = osp.join(self.imgs_dir, img_name+'.jpg') #第i個圖片的名字和路徑 37 label_path = osp.join(self.labels_dir, img_name+'.png')#第i個標簽圖片的名字和路徑 38 39 #生成像素點的標簽矩陣 40 label_img = Image.open(label_path) 41 label_img = np.array(label_img) #生成標簽矩陣 42 # print label image 43 # label_img[label_img == 255] = 0 44 # cv.imwrite('/home/zms/zz/segmentation_models.pytorch-master/tests/test.jpg', label_img) 45 46 img = Image.open(img_path) 47 img = self.preprocess(img) #這是處理后的圖片 48 49 assert img.shape[1:] == label_img.shape, \ 50 f'Image and label {img_name} should be the same size, but are {img.size} and {label_img.size}' 51 52 return {'image': torch.from_numpy(img).type(torch.float), 'label': torch.from_numpy(label_img).type(torch.uint8)} #就是torch.from_numpy()方法把數組轉換成張量,且二者共享內存
(1)首先是對datasets類的屬性進行初始化,在_init_中定義了(self,imgs_dir, labels_dir, text_dir)這里。我們定義了imgs、labels、text三個屬性,這里的作用是輸入這三個的路徑。
(2)因為該數據集不全是用來做圖像分割的,所以我們要先從ImageSets/Segmentation中讀取train.txt文件,通過這個文件再從JPEGimages中取出對應的圖片,具體代碼:
1 f = open(text_dir, 'r')#讀取一個txt文件,路徑為text_dir 2 self.img_names = f.read().splitlines()#取出text文件中的信息
Python中splitlines()函數的作用是:在定義了行邊界的字符串中返回行的列表。除非指定了 keepends 參數,且把其值設置為 True, 否則行的邊界符默認不會包含在字符串中。例如:
>>> str1 = 'ab c\n\nde fg\rkl\r\n' >>> print(str1.splitlines()) ['ab c', '', 'de fg', 'kl'] >>> print(str1.splitlines(True)) ['ab c\n', '\n', 'de fg\r', 'kl\r\n']