[深度學習] pytorch利用Datasets和DataLoader讀取數據


 

本文簡單描述如果自定義dataset,代碼並未經過測試(只是說明思路),為半偽代碼。所有邏輯需按自己需求另外實現:

 

一、分析DataLoader

train_loader = DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

datasets.MNIST()是一個torch.utils.data.Datasets對象,batch_size表示我們定義的batch大小(即每輪訓練使用的批大小),shuffle表示是否打亂數據順序(對於整個datasets里包含的所有數據)。

對於batch_size和shuffle都是根據業務需求來認為指定的,不做過多說明。

對於Datasets對象來說,我們可以根據自己的數據類型來自定義,自己定義一個類,繼承Datasets類。

二、分析Datasets類

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

上述代碼是pytorch中Datasets的源碼,注意成員方法__getitem__和__len__都是未實現的。我們要實現自定義Datasets類來完成數據的讀取,則只需要完成這兩個成員方法的重寫。

  首先,__getitem__()方法用來從datasets中讀取一條數據,這條數據包含訓練圖片(已CV距離)和標簽,參數index表示圖片和標簽在總數據集中的Index。

  __len__()方法返回數據集的總長度(訓練集的總數)。

三、簡單實現MyDatasets類

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z'

import os

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.image as mpimg


# 對所有圖片生成path-label map.txt
def generate_map(root_dir):
    current_path = os.path.abspath(__file__)
    father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".")

    with open(root_dir + 'map.txt', 'w') as wfp:
        for idx in range(10):
            subdir = os.path.join(root_dir, '%d/' % idx)
            for file_name in os.listdir(subdir):
                abs_name = os.path.join(father_path, subdir, file_name)
                linux_abs_name = abs_name.replace("\\", '/')
                wfp.write('{file_dir} {label}\n'.format(file_dir=linux_abs_name, label=idx))


# 實現MyDatasets類
class MyDatasets(Dataset):

    def __init__(self, dir):
        # 獲取數據存放的dir
        # 例如d:/images/
        self.data_dir = dir
        # 用於存放(image,label) tuple的list,存放的數據例如(d:/image/1.png,4)
        self.image_target_list = []
        # 從dir--label的map文件中將所有的tuple對讀取到image_target_list中
        # map.txt中全部存放的是d:/.../image_data/1/3.jpg 1 路徑最好是絕對路徑
        with open(os.path.join(dir, 'map.txt'), 'r') as fp:
            content = fp.readlines()
            str_list = [s.rstrip().split() for s in content]
            # 將所有圖片的dir--label對都放入列表,如果要執行多個epoch,可以在這里多復制幾遍,然后統一shuffle比較好
            self.image_target_list = [(x[0], int(x[1])) for x in str_list]

    def __getitem__(self, index):
        image_label_pair = self.image_target_list[index]
        # 按path讀取圖片數據,並轉換為圖片格式例如[3,32,32]
        img = mpimg.imread(image_label_pair[0])
        return img, image_label_pair[1]

    def __len__(self):
        return len(self.image_target_list)


if __name__ == '__main__':
    # 生成map.txt
    # generate_map('train/')

    train_loader = DataLoader(MyDatasets('train/'), batch_size=128, shuffle=True)

    for step in range(20000):
        for idx, (img, label) in enumerate(train_loader):
            print(img.shape)
            print(label.shape)

上述代碼簡要說明了利用Datasets類和DataLoader類來讀取數據,本例用的是圖片原始數據,大概的結構如下:

如果使用其他形式的數據,例如二進制文件,則需要字節讀取文件,分割成每一張圖片和label,然后從__getitem__中返回就可以了。例如cifar-10數據,我們只需要在__getitem__方法中,按index來讀取對應位置的字節,然后轉換為label和img,並返回。在__len__中返回cifar-10訓練集的總樣本數。DataLoader就可以根據我們提供的index,len以及batch_size,shuffle來返回相應的batch數據和label。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM