(一)Pytorch处理VOC数据集以及dataset代码的编写


1、Dataset+DataLoader实现自定义数据集读取方法

创建自己的数据集需要继承父类torch.utils.data.Dataset,同时需要重载两个私有成员函数:def __len__(self)和def __getitem__(self, index) 。 def __len__(self)应该返回数据集的大小;def __getitem__(self, index)接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取
整个流程如下:

 

 

 1.1整体框架

 1 class MyDataset(torch.utils.data.Dataset):#需要继承torch.utils.data.Dataset
 2     def __init__(self):
 3         #对继承自父类的属性进行初始化(好像没有这句也可以??)
 4         super(MyDataset,self).__init__()
 5         # TODO
 6         #1、初始化一些参数和函数,方便在__getitem__函数中调用。
 7         #2、制作__getitem__函数所要用到的图片和对应标签的list。
 8         #也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。
 9         pass
10     def __getitem__(self, index):
11         # TODO
12         #1、根据list从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。
13         #2、预处理数据(例如torchvision.Transform)。
14         #3、返回数据对(例如图像和标签)。
15         #这里需要注意的是,这步所处理的是index所对应的一个样本。
16         pass
17     def __len__(self):
18         #返回数据集大小
19         return len()

1.2例子讲解

下面用VOC2012数据集的处理作为例子,我们使用UNet网络处理数据集。

1.2.1VOC数据集的介绍

 

 

 这是下载的一个voc2012的数据集,上边的是测试集,下边的是训练+验证集。

VOCdevkit_train结构如下:

 

 

 

该数据集不全是用来做图像分割的,也有目标检测等其他任务的。
Annotations是保存每一张图片的xml信息,其中包含是否用来分割等选项,目标检测Xmin,Xmax,Ymin,Ymax等信息,此处我们用不到
JPEGimages就是需要用到的图片
SegmentationClass是语义分割的label,需要复制整个文件夹
SegmentationObject是实例分割的,暂时用不到
ImageSets这里需要用到,里边有各种任务的txt文件--我们可以根据此来找图像分割需要用到的图片,下图为ImageSets里的文件:

 

 

 我们这里只用到下图的Segmentation部分,打开可以看到以下几个文件,我们可以根据这几个txt文件把图像分割需要的数据提取出去。

 

 

1.2.2具体代码 

 1 import os
 2 import os.path as osp
 3 import logging
 4 import pandas as pd
 5 import numpy as np
 6 from PIL import Image
 7 import torch
 8 from torch.utils.data import Dataset
 9 from PIL import Image
10 from cv2 import cv2 as cv
11 
12 class BasicDataset(Dataset):
13     def __init__(self, imgs_dir, labels_dir, text_dir):
14         self.imgs_dir = imgs_dir
15         self.labels_dir = labels_dir
16         f = open(text_dir, 'r')
17         self.img_names = f.read().splitlines()
18         logging.info(f'Creating dataset with {len(self.img_names)} examples')
19 
20     def __len__(self):
21         return len(self.img_names)
22 
23     def preprocess(self, img):#对图像进行处理的函数
24         img = np.array(img)
25         if len(img.shape) == 2:#如果img数组的维度是二维,就把它变成三维的
26             # label target image
27             img = np.expand_dims(img, axis=2)#在最后增加一维
28 
29         img = img / 255.0 #把矩阵中的像素点值变成0-1范围中的数,这么做的目的是让目标与背景的差异变得明显
30         # HWC to CHW
31         img_trans = img.transpose((2, 0, 1))
32         #return img_trans
33 
34     def __getitem__(self, i):#这里的i就是从文件夹中提取的第i张图片,当作一个索引
35         img_name = self.img_names[i] #第i个图片的名字
36         img_path = osp.join(self.imgs_dir, img_name+'.jpg') #第i个图片的名字和路径
37         label_path = osp.join(self.labels_dir, img_name+'.png')#第i个标签图片的名字和路径
38 
39         #生成像素点的标签矩阵
40         label_img = Image.open(label_path)
41         label_img = np.array(label_img) #生成标签矩阵
42         # print label image
43         # label_img[label_img == 255] = 0
44         # cv.imwrite('/home/zms/zz/segmentation_models.pytorch-master/tests/test.jpg', label_img)
45 
46         img = Image.open(img_path)
47         img = self.preprocess(img) #这是处理后的图片
48 
49         assert img.shape[1:] == label_img.shape, \
50             f'Image and label {img_name} should be the same size, but are {img.size} and {label_img.size}'
51 
52         return {'image': torch.from_numpy(img).type(torch.float), 'label': torch.from_numpy(label_img).type(torch.uint8)} #就是torch.from_numpy()方法把数组转换成张量,且二者共享内存

(1)首先是对datasets类的属性进行初始化,在_init_中定义了(self,imgs_dir, labels_dir, text_dir)这里。我们定义了imgs、labels、text三个属性,这里的作用是输入这三个的路径。

(2)因为该数据集不全是用来做图像分割的,所以我们要先从ImageSets/Segmentation中读取train.txt文件,通过这个文件再从JPEGimages中取出对应的图片,具体代码:

1 f = open(text_dir, 'r')#读取一个txt文件,路径为text_dir
2 self.img_names = f.read().splitlines()#取出text文件中的信息

Python中splitlines()函数的作用是:在定义了行边界的字符串中返回行的列表。除非指定了 keepends 参数,且把其值设置为 True, 否则行的边界符默认不会包含在字符串中。例如:

>>> str1 = 'ab c\n\nde fg\rkl\r\n'
>>> print(str1.splitlines())
['ab c', '', 'de fg', 'kl']
>>> print(str1.splitlines(True))
['ab c\n', '\n', 'de fg\r', 'kl\r\n']

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM