數據傳遞機制
我們首先回顧識別手寫數字的程序:
... Dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True,) dataloader = torch.utils.data.DataLoader(dataset=Dataset, batch_size=64, shuffle=True) ... for epoch in range(EPOCH): for i, (image, label) in enumerate(dataloader): ...
從上面的程序,我們可以知道,在PyTorch中,數據傳遞機制是這樣的:
- 創建Dataset
- Dataset傳遞給DataLoader
- DataLoader迭代產生訓練數據提供給模型
總結這個數據傳遞機制就是,Dataset負責建立索引到樣本的映射,DataLoader負責以特定的方式從數據集中迭代的產生一個個batch的樣本集合。在enumerate過程中實際上是dataloader按照其參數sampler規定的策略調用了其dataset的getitem方法(下文中將介紹該方法)。關於Dataloder和Dataset的關系,具體可參考博客PyTorch中Dataset, DataLoader, Sampler的關系
在上面的識別手寫數字的例子中,數據集是直接下載的,但如果我們自己收集了一些數據,存在電腦文件夾里,我們該如何把這些數據變為可以在PyTorch框架下進行神經網絡訓練的數據集呢,即如何自定義數據集呢?
自定義數據集
torch.utils.data.Dataset 是一個表示數據集的抽象類。任何自定義的數據集都需要繼承這個類並覆寫相關方法。所謂數據集,其實就是一個負責處理索引(index)到樣本(sample)映射的一個類(class)。Pytorch提供兩種數據集: Map式數據集 Iterable式數據集。這里我們只介紹前者。
一個Map式的數據集必須要重寫getitem(self, index)、 len(self) 兩個內建方法,用來表示從索引到樣本的映射(Map)。這樣一個數據集dataset,舉個例子,當使用dataset[idx]命令時,可以在你的硬盤中讀取數據集中第idx張圖片以及其標簽(如果有的話); len(dataset)則會返回這個數據集的容量。
自定義數據集類的范式大致是這樣的:
class CustomDataset(torch.utils.data.Dataset):#需要繼承torch.utils.data.Dataset def __init__(self): # TODO # 1. Initialize file path or list of file names. pass def __getitem__(self, index): # TODO # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data (e.g. torchvision.Transform). # 3. Return a data pair (e.g. image and label). #這里需要注意的是,第一步:read one data,是一個data point pass def __len__(self): # You should change 0 to the total size of your dataset. return 0
根據這個范式,我們舉一個例子。
實例
從kaggle官網下載dogsVScats的數據集(百度網盤下載鏈接見文末),該數據集包含test1文件夾和train文件夾,train文件夾中包含12500張貓的圖片和12500張狗的圖片,圖片的文件名中帶序號:
cat.0.jpg cat.1.jpg cat.2.jpg ... cat.12499.jpg dog.0.jpg dog.1.jpg dog.2.jpg ... dog.12499.jpg
我們把其中前10000張貓的圖片和10000張狗的圖片作為訓練集,把后面的2500張貓的圖片和2500張狗的圖片作為驗證集。貓的label記為0,狗的label記為1。因為圖片大小不一,所以,我們需要對圖像進行transform。
import matplotlib.pyplot as plt import numpy as np import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image import os image_transform = transforms.Compose([ transforms.Resize(256), # 把圖片resize為256*256 transforms.RandomCrop(224), # 隨機裁剪224*224 transforms.RandomHorizontalFlip(), # 水平翻轉 transforms.ToTensor(), # 將圖像轉為Tensor transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 標准化 ]) class DogVsCatDataset(Dataset): # 創建一個叫做DogVsCatDataset的Dataset,繼承自父類torch.utils.data.Dataset def __init__(self, root_dir, train=True, transform=None): """ Args: root_dir (string): Directory with all the images. transform (callable, optional): Optional transform to be applied on a sample. """ self.root_dir = root_dir self.img_path = os.listdir(self.root_dir) if train: self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path)) # 划分訓練集和驗證集 else: self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path)) self.transform = transform def __len__(self): return len(self.img_path) def __getitem__(self, idx): image = Image.open(os.path.join(self.root_dir, self.img_path[idx])) label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1 # label, 貓為0,狗為1 if self.transform: image = self.transform(image) label = torch.from_numpy(np.array([label])) return image, label
我們來測試一下:
if __name__ == '__main__': catanddog_dataset = DogVsCatDataset(root_dir='/Users/wangpeng/Desktop/train', train=False, transform=image_transform) train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4) # num_workers=4表示用4個線程讀取數據 image, label = iter(train_loader).next() # iter()函數把train_loader變為迭代器,然后調用迭代器的next()方法 sample = image[0].squeeze() sample = sample.permute((1, 2, 0)).numpy() sample *= [0.229, 0.224, 0.225] sample += [0.485, 0.456, 0.406] sample = np.clip(sample, 0, 1) plt.imshow(sample) plt.show() print('Label is: {}'.format(label[0].numpy()))
運行結果:
Label is: [0]
dogsVScats數據下載鏈接:鏈接:https://pan.baidu.com/s/17768gqeaX9NrdURV_tR_ow 提取密碼:478x
參考文獻
[1] Pytorch之Dataset與DataLoader,打造你自己的數據集
[2] 基於PyTorch的卷積神經網絡圖像分類——貓狗大戰(一):使用Pytorch定義DataLoader