【深度學習】PyTorch Dataset類的使用與實例分析


Dataset類

介紹

當我們得到一個數據集時,Dataset類可以幫我們提取我們需要的數據,我們用子類繼承Dataset類,我們先給每個數據一個編號(idx),在后面的神經網絡中,初始化Dataset子類實例后,就可以通過這個編號去實例對象中讀取相應的數據,會自動調用__getitem__方法,同時子類對象也會獲取相應真實的Label(人為去復寫即可)

Dataset類的作用:提供一種方式去獲取數據及其對應的真實Label

在Dataset類的子類中,應該有以下函數以實現某些功能:

  1. 獲取每一個數據及其對應的Label
  2. 統計數據集中的數據數量

關於2,神經網絡經常需要對一個數據迭代多次,只有知道當前有多少個數據,進行訓練時才知道要訓練多少次,才能把整個數據集迭代完

Dataset官方文檔解讀

首先看一下Dataset的官方文檔解釋

導入Dataset類:

from torch.utils.data import Dataset

我們可以通過在Jupyter中查看官方文檔

from torch.utils.data import Dataset
help(Dataset)

輸出:

Help on class Dataset in module torch.utils.data.dataset:

class Dataset(typing.Generic)
 |  An abstract class representing a :class:`Dataset`.
 |  
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:`__len__`, which is expected to return the size of the dataset by many
 |  :class:`~torch.utils.data.Sampler` implementations and the default options
 |  of :class:`~torch.utils.data.DataLoader`.
 |  
 |  .. note::
 |    :class:`~torch.utils.data.DataLoader` by default constructs a index
 |    sampler that yields integral indices.  To make it work with a map-style
 |    dataset with non-integral indices/keys, a custom sampler must be provided.
 |  
 |  Method resolution order:
 |      Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __add__(self, other:'Dataset[T_co]') -> 'ConcatDataset[T_co]'
 |  
 |  __getattr__(self, attribute_name)
 |  
 |  __getitem__(self, index) -> +T_co
 |  
 |  ----------------------------------------------------------------------
 |  Class methods defined here:
 |  
 |  register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False) from typing.GenericMeta
 |  
 |  register_function(function_name, function) from typing.GenericMeta
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __abstractmethods__ = frozenset()
 |  
 |  __annotations__ = {'functions': typing.Dict[str, typing.Callable]}
 |  
 |  __args__ = None
 |  
 |  __extra__ = None
 |  
 |  __next_in_mro__ = <class 'object'>
 |      The most base type
 |  
 |  __orig_bases__ = (typing.Generic[+T_co],)
 |  
 |  __origin__ = None
 |  
 |  __parameters__ = (+T_co,)
 |  
 |  __tree_hash__ = -9223371872509358054
 |  
 |  functions = {'concat': functools.partial(<function Dataset.register_da...
 |  
 |  ----------------------------------------------------------------------
 |  Static methods inherited from typing.Generic:
 |  
 |  __new__(cls, *args, **kwds)
 |      Create and return a new object.  See help(type) for accurate signature.

還有一種方式獲取官方文檔信息:

Dataset??

輸出:

Init signature: Dataset(*args, **kwds)
Source:        
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    functions: Dict[str, Callable] = {}

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

    def __getattr__(self, attribute_name):
        if attribute_name in Dataset.functions:
            function = functools.partial(Dataset.functions[attribute_name], self)
            return function
        else:
            raise AttributeError

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
        if function_name in cls.functions:
            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))

        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, Dataset):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
        cls.functions[function_name] = function
File:           d:\environment\anaconda3\envs\py-torch\lib\site-packages\torch\utils\data\dataset.py
Type:           GenericMeta
Subclasses:     Dataset, IterableDataset, Dataset, TensorDataset, ConcatDataset, Subset, Dataset, Subset, Dataset, IterableDataset[+T_co], ...

其中我們可以看到:

"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.
    
    """

以上內容顯示:

該類是一個抽象類,所有的數據集想要在數據與標簽之間建立映射,都需要繼承這個類,所有的子類都需要重寫__getitem__方法,該方法根據索引值獲取每一個數據並且獲取其對應的Label,子類也可以重寫__len__方法,返回數據集的size大小

實例:GetData類

准備工作

首先我們創建一個類,類名為GetData,這個類要繼承Dataset類

class GetData(Dataset):

一般在類中首先需要寫的是__init__方法,此方法用於對象實例化,通常用來提供類中需要使用的變量,可以先不寫

class GetData(Dataset):
    def __init__(self):
        pass

我們可以先寫__getitem__方法:

class GetData(Dataset):
    def __init__(self):
        pass
    
    def __getitem__(self, idx):  # 默認是item,但常改為idx,是index的縮寫
        pass

其中,idx是index的簡稱,就是一個編號,以便以后數據集獲取后,我們使用索引編號訪問每個數據

在實現GetData類之前,我們首先需要解決的問題就是如何讀取一個圖像數據,通常我們使用PIL來讀取

PIL獲取圖像數據

我們使用PIL來讀取數據,它提供一個Image模塊,可以讓我們提取圖像數據,我們先導入這個模塊

from PIL import Image

我們可以在Python Console中看看如何使用 Image

在Python Console中,輸入代碼:

from PIL import Image

將數據集放入項目文件夾,我們需要獲取圖片的絕對路徑,選中具體的圖片,右鍵選擇Copy Path,然后選擇 Absolute path(快捷鍵:Ctrl + Shift + C)

img_path = "D:\\DeepLearning\\dataset\\train\\ants\\0013035.jpg"

在Windows下,路徑分割需要是\\,來表示轉譯

也可以在字符串前面加 r 防轉譯

使用Image的open方法讀取圖片:

img = Image.open(img_path)

可以在Python控制台看到讀取出來的 img,是一個JpegImageFile類的對象

在圖中,可以看到這個對象的一些屬性,比如size

我們查看這個屬性的內容,輸入以下代碼:

img.size

輸出:

(768, 512)

我們可以看到此圖的寬是768,高是512,__len__表示的是這個size元組的長度,有兩個值,所以為 2

show方法顯示圖片:

img.show()

獲取圖片的文件名

從數據集路徑中,獲取所有文件的名字,存儲到一個列表中

一個簡單的例子(在Python Console中):

我們需要借助os模塊

import os
dir_path = "dataset/train/ants_image"
img_path_list = os.listdir(dir_path)

listdir方法會將路徑下的所有文件名(包括后綴名)組成一個列表

我們可以使用索引去訪問列表中的每個文件名

img_path_list[0]
Out[14]: '0013035.jpg'

構建數據集路徑

我們需要搭建數據集的路徑表示,一個根目錄路徑和一個具體的子目錄路徑,以作為不同數據集的區分

一個簡單的案例,在Python Console中輸入:

root_dir = "dataset/train"
child_dir = "ants_image"

我們使用os.path.join方法,將兩個路徑拼接起來,就得到了ants子數據集的相對路徑

path = os.path.join(root_dir, child_dir)

path的值此時是:

path={str}'dataset/train\\ants_image'

我們有了這個數據集的路徑后,就可以使用之前所講的listdir方法,獲取這個路徑中所有文件的文件名,存儲到一個列表中

img_path_list = os.listdir(path)
idx = 0
img_path_list[idx]
Out[21]: '0013035.jpg'

可以看到結果與我們之前的小案例是一樣的

有了具體的名字,我們還可以將這個文件名與路徑進行組合,然后使用PIL獲取具體的圖像img對象

img_name = img_path_list[idx]
img_item_path = os.path.join(root_dir, child_dir, img_name)
img = Image.open(img_item_path)

在掌握了如何組裝路徑、獲取路徑中的文件名以及獲取具體圖像對象后,我們可以完善我們的__init____getitem__方法了

完善__init__方法

在init中為啥使用self:一個函數中的變量是不能拿到另外一個函數中使用的,self可以當做類中的全局變量

class GetData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path_list = os.listdir(self.path)

很簡單,就是接收實例化時傳入的參數:獲取根目錄路徑、子目錄路徑

然后將兩個路徑進行組合,就得到了目標數據集的路徑

我們將這個路徑作為參數傳入listdir函數,從而讓img_path_list中存儲該目錄下所有文件名(包含后綴名)

此時通過索引就可以輕松獲取每個文件名

接下來,我們要使用這些初始化的信息去獲取其中的每一個圖片的JpegImageFile對象

完善__getitem__方法

我們在初始化中,已經通過組裝數據集路徑,進而通過listdir方法獲取了數據集中每個文件的文件名,存入了一個列表中。

在__getitem__方法中,默認會有一個 item 參數,常命名為 idx,這個參數是一個索引編號,用於對我們初始化中得到的文件名列表進行索引訪問,我們就得到了具體的文件名,然后與根目錄、子目錄再次組裝,得到具體數據的相對路徑,我們可以通過這個路徑獲取到索引編號對應的數據對象本身。

這樣巧妙的讓索引與數據集中的具體數據對應了起來

def __getitem__(self, idx):
    img_name = self.img_path_list[idx]  # 從文件名列表中獲取了文件名
    img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 組裝路徑,獲得了圖片具體的路徑

獲取了具體的圖像路徑后,我們需要使用PIL讀取這個圖像

def __getitem__(self, idx):
    img_name = self.img_path[idx]  
    img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) 
    img = Image.open(img_item_path)
    label = self.label_dir
    return img, label

此處img是一個JpegImageFile對象,label是一個字符串

自此,這個函數我們就實現完成了

以后使用這個類進行實例化時,傳入的參數是根目錄路徑,以及對應的label名,我們就可以得到一個GetData對象。

有了這個GetData對象后,我們可以直接使用索引來獲取具體的圖像對象(類:JpegImageFile),因為__getitem__方法已經幫我們實現了,我們只需要使用索引即可調用__getitem__方法,會返回我們根據索引提取到的對應數據的圖像對象以及其label

root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = GetData(root_dir, ants_label_dir)
bees_dataset = GetData(root_dir, bees_label_dir)
img1, label1 = ants_dataset[0]  # 返回一個元組,返回值是__getitem__方法的返回值 
img2, label2 = bees_dataset[0]

完善__len__方法

__len__實現很簡單

主要功能是獲取數據集的長度,由於我們在初始化中已經獲取了所有文件名的列表,所以只需要知道這個列表的長度,就知道了有多少個文件,也就是知道了有多少個具體的數據

def __len__(self):
    return len(self.img_path_list)

組合數據集

我們還可以將兩個數據集對象進行組合,組合成一個大的數據集對象

train_dataset = ants_dataset + bees_dataset

我們看看這三個數據集對象的大小(在python Console中):

len1 = len(ants_dataset)
len2 = len(bees_dataset)
len3 = len(train_dataset)

輸出:

124
121
245

我們可以看到剛好 $$124 + 121 = 245$$

而對這個組合的數據集的訪問也很有意思,也同樣是使用索引,0 ~ 123 都是ants數據集的內容,124 - 244 都是bees數據集的內容

img1, label1 = train_dataset[123]
img1.show()
img2, label2 = train_dataset[124]
img2.show()

完整代碼

from torch.utils.data import Dataset
from PIL import Image
import os

class GetData(Dataset):

    # 初始化為整個class提供全局變量,為后續方法提供一些量
    def __init__(self, root_dir, label_dir):

        # self
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path_list = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path_list[idx]  # 只獲取了文件名
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 每個圖片的位置
        # 讀取圖片
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)

root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = GetData(root_dir, ants_label_dir)
bees_dataset = GeyData(root_dir, bees_label_dir)
img, lable = ants_dataset[0] # 返回一個元組,返回值就是__getitem__的返回值


# 獲取整個訓練集,就是對兩個數據集進行了拼接
train_dataset = ants_dataset + bees_dataset

len1 = len(ants_dataset)  # 124
len2 = len(bees_dataset)  # 121
len = len(train_dataset) # 245

img1, label1 = train_dataset[123]  # 獲取的是螞蟻的最后一個
img2, label2 = train_dataset[124]  # 獲取的是蜜蜂第一個


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM