(一)Pytorch處理VOC數據集以及dataset代碼的編寫


1、Dataset+DataLoader實現自定義數據集讀取方法

創建自己的數據集需要繼承父類torch.utils.data.Dataset,同時需要重載兩個私有成員函數:def __len__(self)和def __getitem__(self, index) 。 def __len__(self)應該返回數據集的大小;def __getitem__(self, index)接收一個index,然后返回圖片數據和標簽,這個index通常指的是一個list的index,這個list的每個元素就包含了圖片數據的路徑和標簽信息。如何制作這個list呢,通常的方法是將圖片的路徑和標簽信息存儲在一個txt中,然后從該txt中讀取
整個流程如下:

 

 

 1.1整體框架

 1 class MyDataset(torch.utils.data.Dataset):#需要繼承torch.utils.data.Dataset
 2     def __init__(self):
 3         #對繼承自父類的屬性進行初始化(好像沒有這句也可以??)
 4         super(MyDataset,self).__init__()
 5         # TODO
 6         #1、初始化一些參數和函數,方便在__getitem__函數中調用。
 7         #2、制作__getitem__函數所要用到的圖片和對應標簽的list。
 8         #也就是在這個模塊里,我們所做的工作就是初始化該類的一些基本參數。
 9         pass
10     def __getitem__(self, index):
11         # TODO
12         #1、根據list從文件中讀取一個數據(例如,使用numpy.fromfile,PIL.Image.open)。
13         #2、預處理數據(例如torchvision.Transform)。
14         #3、返回數據對(例如圖像和標簽)。
15         #這里需要注意的是,這步所處理的是index所對應的一個樣本。
16         pass
17     def __len__(self):
18         #返回數據集大小
19         return len()

1.2例子講解

下面用VOC2012數據集的處理作為例子,我們使用UNet網絡處理數據集。

1.2.1VOC數據集的介紹

 

 

 這是下載的一個voc2012的數據集,上邊的是測試集,下邊的是訓練+驗證集。

VOCdevkit_train結構如下:

 

 

 

該數據集不全是用來做圖像分割的,也有目標檢測等其他任務的。
Annotations是保存每一張圖片的xml信息,其中包含是否用來分割等選項,目標檢測Xmin,Xmax,Ymin,Ymax等信息,此處我們用不到
JPEGimages就是需要用到的圖片
SegmentationClass是語義分割的label,需要復制整個文件夾
SegmentationObject是實例分割的,暫時用不到
ImageSets這里需要用到,里邊有各種任務的txt文件--我們可以根據此來找圖像分割需要用到的圖片,下圖為ImageSets里的文件:

 

 

 我們這里只用到下圖的Segmentation部分,打開可以看到以下幾個文件,我們可以根據這幾個txt文件把圖像分割需要的數據提取出去。

 

 

1.2.2具體代碼 

 1 import os
 2 import os.path as osp
 3 import logging
 4 import pandas as pd
 5 import numpy as np
 6 from PIL import Image
 7 import torch
 8 from torch.utils.data import Dataset
 9 from PIL import Image
10 from cv2 import cv2 as cv
11 
12 class BasicDataset(Dataset):
13     def __init__(self, imgs_dir, labels_dir, text_dir):
14         self.imgs_dir = imgs_dir
15         self.labels_dir = labels_dir
16         f = open(text_dir, 'r')
17         self.img_names = f.read().splitlines()
18         logging.info(f'Creating dataset with {len(self.img_names)} examples')
19 
20     def __len__(self):
21         return len(self.img_names)
22 
23     def preprocess(self, img):#對圖像進行處理的函數
24         img = np.array(img)
25         if len(img.shape) == 2:#如果img數組的維度是二維,就把它變成三維的
26             # label target image
27             img = np.expand_dims(img, axis=2)#在最后增加一維
28 
29         img = img / 255.0 #把矩陣中的像素點值變成0-1范圍中的數,這么做的目的是讓目標與背景的差異變得明顯
30         # HWC to CHW
31         img_trans = img.transpose((2, 0, 1))
32         #return img_trans
33 
34     def __getitem__(self, i):#這里的i就是從文件夾中提取的第i張圖片,當作一個索引
35         img_name = self.img_names[i] #第i個圖片的名字
36         img_path = osp.join(self.imgs_dir, img_name+'.jpg') #第i個圖片的名字和路徑
37         label_path = osp.join(self.labels_dir, img_name+'.png')#第i個標簽圖片的名字和路徑
38 
39         #生成像素點的標簽矩陣
40         label_img = Image.open(label_path)
41         label_img = np.array(label_img) #生成標簽矩陣
42         # print label image
43         # label_img[label_img == 255] = 0
44         # cv.imwrite('/home/zms/zz/segmentation_models.pytorch-master/tests/test.jpg', label_img)
45 
46         img = Image.open(img_path)
47         img = self.preprocess(img) #這是處理后的圖片
48 
49         assert img.shape[1:] == label_img.shape, \
50             f'Image and label {img_name} should be the same size, but are {img.size} and {label_img.size}'
51 
52         return {'image': torch.from_numpy(img).type(torch.float), 'label': torch.from_numpy(label_img).type(torch.uint8)} #就是torch.from_numpy()方法把數組轉換成張量,且二者共享內存

(1)首先是對datasets類的屬性進行初始化,在_init_中定義了(self,imgs_dir, labels_dir, text_dir)這里。我們定義了imgs、labels、text三個屬性,這里的作用是輸入這三個的路徑。

(2)因為該數據集不全是用來做圖像分割的,所以我們要先從ImageSets/Segmentation中讀取train.txt文件,通過這個文件再從JPEGimages中取出對應的圖片,具體代碼:

1 f = open(text_dir, 'r')#讀取一個txt文件,路徑為text_dir
2 self.img_names = f.read().splitlines()#取出text文件中的信息

Python中splitlines()函數的作用是:在定義了行邊界的字符串中返回行的列表。除非指定了 keepends 參數,且把其值設置為 True, 否則行的邊界符默認不會包含在字符串中。例如:

>>> str1 = 'ab c\n\nde fg\rkl\r\n'
>>> print(str1.splitlines())
['ab c', '', 'de fg', 'kl']
>>> print(str1.splitlines(True))
['ab c\n', '\n', 'de fg\r', 'kl\r\n']

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM