數據傳遞機制
我們首先回顧識別手寫數字的程序:
...
Dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True,)
dataloader = torch.utils.data.DataLoader(dataset=Dataset, batch_size=64, shuffle=True)
...
for epoch in range(EPOCH):
for i, (image, label) in enumerate(dataloader):
...
從上面的程序,我們可以知道,在PyTorch中,數據傳遞機制是這樣的:
- 創建Dataset
- Dataset傳遞給DataLoader
- DataLoader迭代產生訓練數據提供給模型
總結這個數據傳遞機制就是,Dataset負責建立索引到樣本的映射,DataLoader負責以特定的方式從數據集中迭代的產生一個個batch的樣本集合。在enumerate過程中實際上是dataloader按照其參數sampler規定的策略調用了其dataset的getitem方法(下文中將介紹該方法)。關於Dataloder和Dataset的關系,具體可參考博客PyTorch中Dataset, DataLoader, Sampler的關系
在上面的識別手寫數字的例子中,數據集是直接下載的,但如果我們自己收集了一些數據,存在電腦文件夾里,我們該如何把這些數據變為可以在PyTorch框架下進行神經網絡訓練的數據集呢,即如何自定義數據集呢?
自定義數據集
torch.utils.data.Dataset 是一個表示數據集的抽象類。任何自定義的數據集都需要繼承這個類並覆寫相關方法。所謂數據集,其實就是一個負責處理索引(index)到樣本(sample)映射的一個類(class)。Pytorch提供兩種數據集: Map式數據集 Iterable式數據集。這里我們只介紹前者。
一個Map式的數據集必須要重寫getitem(self, index)、 len(self) 兩個內建方法,用來表示從索引到樣本的映射(Map)。這樣一個數據集dataset,舉個例子,當使用dataset[idx]命令時,可以在你的硬盤中讀取數據集中第idx張圖片以及其標簽(如果有的話); len(dataset)則會返回這個數據集的容量。
自定義數據集類的范式大致是這樣的:
class CustomDataset(torch.utils.data.Dataset):#需要繼承torch.utils.data.Dataset
def __init__(self):
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
#這里需要注意的是,第一步:read one data,是一個data point
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
根據這個范式,我們舉一個例子。
實例
從kaggle官網下載dogsVScats的數據集(百度網盤下載鏈接見文末),該數據集包含test1文件夾和train文件夾,train文件夾中包含12500張貓的圖片和12500張狗的圖片,圖片的文件名中帶序號:
cat.0.jpg cat.1.jpg cat.2.jpg ... cat.12499.jpg dog.0.jpg dog.1.jpg dog.2.jpg ... dog.12499.jpg
我們把其中前10000張貓的圖片和10000張狗的圖片作為訓練集,把后面的2500張貓的圖片和2500張狗的圖片作為驗證集。貓的label記為0,狗的label記為1。因為圖片大小不一,所以,我們需要對圖像進行transform。
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
image_transform = transforms.Compose([
transforms.Resize(256), # 把圖片resize為256*256
transforms.RandomCrop(224), # 隨機裁剪224*224
transforms.RandomHorizontalFlip(), # 水平翻轉
transforms.ToTensor(), # 將圖像轉為Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 標准化
])
class DogVsCatDataset(Dataset): # 創建一個叫做DogVsCatDataset的Dataset,繼承自父類torch.utils.data.Dataset
def __init__(self, root_dir, train=True, transform=None):
"""
Args:
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied on a sample.
"""
self.root_dir = root_dir
self.img_path = os.listdir(self.root_dir)
if train:
self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path)) # 划分訓練集和驗證集
else:
self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))
self.transform = transform
def __len__(self):
return len(self.img_path)
def __getitem__(self, idx):
image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))
label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1 # label, 貓為0,狗為1
if self.transform:
image = self.transform(image)
label = torch.from_numpy(np.array([label]))
return image, label
我們來測試一下:
if __name__ == '__main__':
catanddog_dataset = DogVsCatDataset(root_dir='/Users/wangpeng/Desktop/train',
train=False,
transform=image_transform)
train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4) # num_workers=4表示用4個線程讀取數據
image, label = iter(train_loader).next() # iter()函數把train_loader變為迭代器,然后調用迭代器的next()方法
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
sample = np.clip(sample, 0, 1)
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))
運行結果:

Label is: [0]
dogsVScats數據下載鏈接:鏈接:https://pan.baidu.com/s/17768gqeaX9NrdURV_tR_ow 提取密碼:478x
參考文獻
[1] Pytorch之Dataset與DataLoader,打造你自己的數據集
[2] 基於PyTorch的卷積神經網絡圖像分類——貓狗大戰(一):使用Pytorch定義DataLoader
