pytorch讀取圖片,主要是通過Dataset類。
Dataset類源代碼如下:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
這個類中最核心的就是getitem函數,上面介紹中寫的是這個函數提供一個合理范圍內的index。我們在自己定義數據集的時候,在這個類中,我們一般是定義這個函數的功能是接受一個index,然后返回圖片數據和標簽。所以在這個函數中,需要包含打開圖片的函數和獲取圖片lable的語句
getitem函數接受的是一個index,這個index通常指的是一個list中index,這個list中的每個元素就是對應的每個圖片的文件路徑和標簽。
所以在讀取自己數據的時候基本流程就是這樣的:
首先制作圖片存儲路徑和標簽信息的txt
然后將這個信息轉化為list
通過這個list中的index,使用getitem函數,我們獲取對應的圖片數據和標簽信息
現在問題是如何制作這個一個list。這個東西我們一般是外部制作就好,保存為一個txt格式就好
然后我們制作一個Dataset子類
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.strip()
words = line.split()
imgs.append((words[0], int(words[1]))) # words[0]是路徑 words[1]是類別數
self.imgs = imgs # 最主要就是要生成這個list, 然后DataLoader中給index,通過getitem讀取圖片數據
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(fn).convert('RGB') # 像素值 0~255,在transfrom.totensor會除以255,使像素值變成 0~1
if self.transform is not None:
img = self.transform(img) # 在這里做transform,轉為tensor等等
return img, label
def __len__(self):
return len(self.imgs)
注意看我自己定義的類,在初始化函數中,我通過對txt文件的讀取,得到了一個list,也就是self.imgs
然后在__getitem__ 函數中,通過index,我們得到文件路徑和lable,然后使用open函數,將圖像文件打開並轉化為RGB數據,同時進行一些相應的轉化
這個部分建立好了,其實自定義數據集基本就好了,因為接下來的操作就交給了DataLoder,代碼基本不需要變化。
我現在有一個思考,就是說上面我說的圖像數據,如果是文本數據呢?我如何進行自定義數據呢?