PyTorch自定義數據集


數據傳遞機制

我們首先回顧識別手寫數字的程序:

...
Dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True,)
dataloader = torch.utils.data.DataLoader(dataset=Dataset, batch_size=64, shuffle=True)
...
for epoch in range(EPOCH):
    for i, (image, label) in enumerate(dataloader):
        ...

從上面的程序,我們可以知道,在PyTorch中,數據傳遞機制是這樣的:

  1. 創建Dataset
  2. Dataset傳遞給DataLoader
  3. DataLoader迭代產生訓練數據提供給模型

總結這個數據傳遞機制就是,Dataset負責建立索引到樣本的映射,DataLoader負責以特定的方式從數據集中迭代的產生一個個batch的樣本集合。在enumerate過程中實際上是dataloader按照其參數sampler規定的策略調用了其dataset的getitem方法(下文中將介紹該方法)。關於Dataloder和Dataset的關系,具體可參考博客PyTorch中Dataset, DataLoader, Sampler的關系

在上面的識別手寫數字的例子中,數據集是直接下載的,但如果我們自己收集了一些數據,存在電腦文件夾里,我們該如何把這些數據變為可以在PyTorch框架下進行神經網絡訓練的數據集呢,即如何自定義數據集呢?

自定義數據集

torch.utils.data.Dataset 是一個表示數據集的抽象類。任何自定義的數據集都需要繼承這個類並覆寫相關方法。所謂數據集,其實就是一個負責處理索引(index)到樣本(sample)映射的一個類(class)。Pytorch提供兩種數據集: Map式數據集 Iterable式數據集。這里我們只介紹前者。

一個Map式的數據集必須要重寫getitem(self, index)、 len(self) 兩個內建方法,用來表示從索引到樣本的映射(Map)。這樣一個數據集dataset,舉個例子,當使用dataset[idx]命令時,可以在你的硬盤中讀取數據集中第idx張圖片以及其標簽(如果有的話); len(dataset)則會返回這個數據集的容量。

自定義數據集類的范式大致是這樣的:

class CustomDataset(torch.utils.data.Dataset):#需要繼承torch.utils.data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #這里需要注意的是,第一步:read one data,是一個data point
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

根據這個范式,我們舉一個例子。

實例

從kaggle官網下載dogsVScats的數據集(百度網盤下載鏈接見文末),該數據集包含test1文件夾和train文件夾,train文件夾中包含12500張貓的圖片和12500張狗的圖片,圖片的文件名中帶序號:

cat.0.jpg
cat.1.jpg
cat.2.jpg
...
cat.12499.jpg
dog.0.jpg
dog.1.jpg
dog.2.jpg
...
dog.12499.jpg

我們把其中前10000張貓的圖片和10000張狗的圖片作為訓練集,把后面的2500張貓的圖片和2500張狗的圖片作為驗證集。貓的label記為0,狗的label記為1。因為圖片大小不一,所以,我們需要對圖像進行transform。

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
image_transform = transforms.Compose([
    transforms.Resize(256),               # 把圖片resize為256*256
    transforms.RandomCrop(224),           # 隨機裁剪224*224
    transforms.RandomHorizontalFlip(),    # 水平翻轉
    transforms.ToTensor(),                # 將圖像轉為Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])   # 標准化
])

class DogVsCatDataset(Dataset):   # 創建一個叫做DogVsCatDataset的Dataset,繼承自父類torch.utils.data.Dataset
    def __init__(self, root_dir, train=True, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.img_path = os.listdir(self.root_dir)
        if train:
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path))    # 划分訓練集和驗證集
        else:
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))
        self.transform = transform

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))
        label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1        # label, 貓為0,狗為1
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array([label]))
        return image, label

我們來測試一下:

if __name__ == '__main__':
    catanddog_dataset = DogVsCatDataset(root_dir='/Users/wangpeng/Desktop/train',
                                        train=False,
                                        transform=image_transform)
    train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)   # num_workers=4表示用4個線程讀取數據
    image, label = iter(train_loader).next()   # iter()函數把train_loader變為迭代器,然后調用迭代器的next()方法
    sample = image[0].squeeze()
    sample = sample.permute((1, 2, 0)).numpy()
    sample *= [0.229, 0.224, 0.225]
    sample += [0.485, 0.456, 0.406]
    sample = np.clip(sample, 0, 1)
    plt.imshow(sample)
    plt.show()
    print('Label is: {}'.format(label[0].numpy()))

運行結果:

Label is: [0] 

 

dogsVScats數據下載鏈接:鏈接:https://pan.baidu.com/s/17768gqeaX9NrdURV_tR_ow  提取密碼:478x

 

參考文獻

[1] Pytorch之Dataset與DataLoader,打造你自己的數據集

[2] 基於PyTorch的卷積神經網絡圖像分類——貓狗大戰(一):使用Pytorch定義DataLoader

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM