1、Dataset+DataLoader实现自定义数据集读取方法
创建自己的数据集需要继承父类torch.utils.data.Dataset,同时需要重载两个私有成员函数:def __len__(self)和def __getitem__(self, index) 。 def __len__(self)应该返回数据集的大小;def __getitem__(self, index)接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。
整个流程如下:
1.1整体框架
1 class MyDataset(torch.utils.data.Dataset):#需要继承torch.utils.data.Dataset 2 def __init__(self): 3 #对继承自父类的属性进行初始化(好像没有这句也可以??) 4 super(MyDataset,self).__init__() 5 # TODO 6 #1、初始化一些参数和函数,方便在__getitem__函数中调用。 7 #2、制作__getitem__函数所要用到的图片和对应标签的list。 8 #也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。 9 pass 10 def __getitem__(self, index): 11 # TODO 12 #1、根据list从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。 13 #2、预处理数据(例如torchvision.Transform)。 14 #3、返回数据对(例如图像和标签)。 15 #这里需要注意的是,这步所处理的是index所对应的一个样本。 16 pass 17 def __len__(self): 18 #返回数据集大小 19 return len()
1.2例子讲解
下面用VOC2012数据集的处理作为例子,我们使用UNet网络处理数据集。
1.2.1VOC数据集的介绍
这是下载的一个voc2012的数据集,上边的是测试集,下边的是训练+验证集。
VOCdevkit_train结构如下:
该数据集不全是用来做图像分割的,也有目标检测等其他任务的。
Annotations是保存每一张图片的xml信息,其中包含是否用来分割等选项,目标检测Xmin,Xmax,Ymin,Ymax等信息,此处我们用不到
JPEGimages就是需要用到的图片
SegmentationClass是语义分割的label,需要复制整个文件夹
SegmentationObject是实例分割的,暂时用不到
ImageSets这里需要用到,里边有各种任务的txt文件--我们可以根据此来找图像分割需要用到的图片,下图为ImageSets里的文件:
我们这里只用到下图的Segmentation部分,打开可以看到以下几个文件,我们可以根据这几个txt文件把图像分割需要的数据提取出去。
1.2.2具体代码
1 import os 2 import os.path as osp 3 import logging 4 import pandas as pd 5 import numpy as np 6 from PIL import Image 7 import torch 8 from torch.utils.data import Dataset 9 from PIL import Image 10 from cv2 import cv2 as cv 11 12 class BasicDataset(Dataset): 13 def __init__(self, imgs_dir, labels_dir, text_dir): 14 self.imgs_dir = imgs_dir 15 self.labels_dir = labels_dir 16 f = open(text_dir, 'r') 17 self.img_names = f.read().splitlines() 18 logging.info(f'Creating dataset with {len(self.img_names)} examples') 19 20 def __len__(self): 21 return len(self.img_names) 22 23 def preprocess(self, img):#对图像进行处理的函数 24 img = np.array(img) 25 if len(img.shape) == 2:#如果img数组的维度是二维,就把它变成三维的 26 # label target image 27 img = np.expand_dims(img, axis=2)#在最后增加一维 28 29 img = img / 255.0 #把矩阵中的像素点值变成0-1范围中的数,这么做的目的是让目标与背景的差异变得明显 30 # HWC to CHW 31 img_trans = img.transpose((2, 0, 1)) 32 #return img_trans 33 34 def __getitem__(self, i):#这里的i就是从文件夹中提取的第i张图片,当作一个索引 35 img_name = self.img_names[i] #第i个图片的名字 36 img_path = osp.join(self.imgs_dir, img_name+'.jpg') #第i个图片的名字和路径 37 label_path = osp.join(self.labels_dir, img_name+'.png')#第i个标签图片的名字和路径 38 39 #生成像素点的标签矩阵 40 label_img = Image.open(label_path) 41 label_img = np.array(label_img) #生成标签矩阵 42 # print label image 43 # label_img[label_img == 255] = 0 44 # cv.imwrite('/home/zms/zz/segmentation_models.pytorch-master/tests/test.jpg', label_img) 45 46 img = Image.open(img_path) 47 img = self.preprocess(img) #这是处理后的图片 48 49 assert img.shape[1:] == label_img.shape, \ 50 f'Image and label {img_name} should be the same size, but are {img.size} and {label_img.size}' 51 52 return {'image': torch.from_numpy(img).type(torch.float), 'label': torch.from_numpy(label_img).type(torch.uint8)} #就是torch.from_numpy()方法把数组转换成张量,且二者共享内存
(1)首先是对datasets类的属性进行初始化,在_init_中定义了(self,imgs_dir, labels_dir, text_dir)这里。我们定义了imgs、labels、text三个属性,这里的作用是输入这三个的路径。
(2)因为该数据集不全是用来做图像分割的,所以我们要先从ImageSets/Segmentation中读取train.txt文件,通过这个文件再从JPEGimages中取出对应的图片,具体代码:
1 f = open(text_dir, 'r')#读取一个txt文件,路径为text_dir 2 self.img_names = f.read().splitlines()#取出text文件中的信息
Python中splitlines()函数的作用是:在定义了行边界的字符串中返回行的列表。除非指定了 keepends 参数,且把其值设置为 True, 否则行的边界符默认不会包含在字符串中。例如:
>>> str1 = 'ab c\n\nde fg\rkl\r\n' >>> print(str1.splitlines()) ['ab c', '', 'de fg', 'kl'] >>> print(str1.splitlines(True)) ['ab c\n', '\n', 'de fg\r', 'kl\r\n']