Dataset類
介紹
當我們得到一個數據集時,Dataset類可以幫我們提取我們需要的數據,我們用子類繼承Dataset類,我們先給每個數據一個編號(idx),在后面的神經網絡中,初始化Dataset子類實例后,就可以通過這個編號去實例對象中讀取相應的數據,會自動調用__getitem__方法,同時子類對象也會獲取相應真實的Label(人為去復寫即可)
Dataset類的作用:提供一種方式去獲取數據及其對應的真實Label
在Dataset類的子類中,應該有以下函數以實現某些功能:
- 獲取每一個數據及其對應的Label
- 統計數據集中的數據數量
關於2,神經網絡經常需要對一個數據迭代多次,只有知道當前有多少個數據,進行訓練時才知道要訓練多少次,才能把整個數據集迭代完
Dataset官方文檔解讀
首先看一下Dataset的官方文檔解釋
導入Dataset類:
from torch.utils.data import Dataset
我們可以通過在Jupyter中查看官方文檔
from torch.utils.data import Dataset
help(Dataset)
輸出:
Help on class Dataset in module torch.utils.data.dataset:
class Dataset(typing.Generic)
| An abstract class representing a :class:`Dataset`.
|
| All datasets that represent a map from keys to data samples should subclass
| it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
| data sample for a given key. Subclasses could also optionally overwrite
| :meth:`__len__`, which is expected to return the size of the dataset by many
| :class:`~torch.utils.data.Sampler` implementations and the default options
| of :class:`~torch.utils.data.DataLoader`.
|
| .. note::
| :class:`~torch.utils.data.DataLoader` by default constructs a index
| sampler that yields integral indices. To make it work with a map-style
| dataset with non-integral indices/keys, a custom sampler must be provided.
|
| Method resolution order:
| Dataset
| typing.Generic
| builtins.object
|
| Methods defined here:
|
| __add__(self, other:'Dataset[T_co]') -> 'ConcatDataset[T_co]'
|
| __getattr__(self, attribute_name)
|
| __getitem__(self, index) -> +T_co
|
| ----------------------------------------------------------------------
| Class methods defined here:
|
| register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False) from typing.GenericMeta
|
| register_function(function_name, function) from typing.GenericMeta
|
| ----------------------------------------------------------------------
| Data descriptors defined here:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| ----------------------------------------------------------------------
| Data and other attributes defined here:
|
| __abstractmethods__ = frozenset()
|
| __annotations__ = {'functions': typing.Dict[str, typing.Callable]}
|
| __args__ = None
|
| __extra__ = None
|
| __next_in_mro__ = <class 'object'>
| The most base type
|
| __orig_bases__ = (typing.Generic[+T_co],)
|
| __origin__ = None
|
| __parameters__ = (+T_co,)
|
| __tree_hash__ = -9223371872509358054
|
| functions = {'concat': functools.partial(<function Dataset.register_da...
|
| ----------------------------------------------------------------------
| Static methods inherited from typing.Generic:
|
| __new__(cls, *args, **kwds)
| Create and return a new object. See help(type) for accurate signature.
還有一種方式獲取官方文檔信息:
Dataset??
輸出:
Init signature: Dataset(*args, **kwds)
Source:
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
functions: Dict[str, Callable] = {}
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
def __getattr__(self, attribute_name):
if attribute_name in Dataset.functions:
function = functools.partial(Dataset.functions[attribute_name], self)
return function
else:
raise AttributeError
@classmethod
def register_function(cls, function_name, function):
cls.functions[function_name] = function
@classmethod
def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
if function_name in cls.functions:
raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))
def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
result_pipe = cls(source_dp, *args, **kwargs)
if isinstance(result_pipe, Dataset):
if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
if function_name not in UNTRACABLE_DATAFRAME_PIPES:
result_pipe = result_pipe.trace_as_dataframe()
return result_pipe
function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
cls.functions[function_name] = function
File: d:\environment\anaconda3\envs\py-torch\lib\site-packages\torch\utils\data\dataset.py
Type: GenericMeta
Subclasses: Dataset, IterableDataset, Dataset, TensorDataset, ConcatDataset, Subset, Dataset, Subset, Dataset, IterableDataset[+T_co], ...
其中我們可以看到:
"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
"""
以上內容顯示:
該類是一個抽象類,所有的數據集想要在數據與標簽之間建立映射,都需要繼承這個類,所有的子類都需要重寫__getitem__
方法,該方法根據索引值獲取每一個數據並且獲取其對應的Label,子類也可以重寫__len__
方法,返回數據集的size大小
實例:GetData類
准備工作
首先我們創建一個類,類名為GetData,這個類要繼承Dataset類
class GetData(Dataset):
一般在類中首先需要寫的是__init__
方法,此方法用於對象實例化,通常用來提供類中需要使用的變量,可以先不寫
class GetData(Dataset):
def __init__(self):
pass
我們可以先寫__getitem__
方法:
class GetData(Dataset):
def __init__(self):
pass
def __getitem__(self, idx): # 默認是item,但常改為idx,是index的縮寫
pass
其中,idx是index的簡稱,就是一個編號,以便以后數據集獲取后,我們使用索引編號訪問每個數據
在實現GetData類之前,我們首先需要解決的問題就是如何讀取一個圖像數據,通常我們使用PIL來讀取
PIL獲取圖像數據
我們使用PIL來讀取數據,它提供一個Image模塊,可以讓我們提取圖像數據,我們先導入這個模塊
from PIL import Image
我們可以在Python Console中看看如何使用 Image
在Python Console中,輸入代碼:
from PIL import Image
將數據集放入項目文件夾,我們需要獲取圖片的絕對路徑,選中具體的圖片,右鍵選擇Copy Path,然后選擇 Absolute path(快捷鍵:Ctrl + Shift + C)
img_path = "D:\\DeepLearning\\dataset\\train\\ants\\0013035.jpg"
在Windows下,路徑分割需要是
\\
,來表示轉譯也可以在字符串前面加
r
防轉譯
使用Image的open方法讀取圖片:
img = Image.open(img_path)
可以在Python控制台看到讀取出來的 img,是一個JpegImageFile類的對象
在圖中,可以看到這個對象的一些屬性,比如size
我們查看這個屬性的內容,輸入以下代碼:
img.size
輸出:
(768, 512)
我們可以看到此圖的寬是768,高是512,__len__
表示的是這個size元組的長度,有兩個值,所以為 2
show方法顯示圖片:
img.show()
獲取圖片的文件名
從數據集路徑中,獲取所有文件的名字,存儲到一個列表中
一個簡單的例子(在Python Console中):
我們需要借助os模塊
import os
dir_path = "dataset/train/ants_image"
img_path_list = os.listdir(dir_path)
listdir方法會將路徑下的所有文件名(包括后綴名)組成一個列表
我們可以使用索引去訪問列表中的每個文件名
img_path_list[0]
Out[14]: '0013035.jpg'
構建數據集路徑
我們需要搭建數據集的路徑表示,一個根目錄路徑和一個具體的子目錄路徑,以作為不同數據集的區分
一個簡單的案例,在Python Console中輸入:
root_dir = "dataset/train"
child_dir = "ants_image"
我們使用os.path.join
方法,將兩個路徑拼接起來,就得到了ants子數據集的相對路徑
path = os.path.join(root_dir, child_dir)
path的值此時是:
path={str}'dataset/train\\ants_image'
我們有了這個數據集的路徑后,就可以使用之前所講的listdir方法,獲取這個路徑中所有文件的文件名,存儲到一個列表中
img_path_list = os.listdir(path)
idx = 0
img_path_list[idx]
Out[21]: '0013035.jpg'
可以看到結果與我們之前的小案例是一樣的
有了具體的名字,我們還可以將這個文件名與路徑進行組合,然后使用PIL獲取具體的圖像img對象
img_name = img_path_list[idx]
img_item_path = os.path.join(root_dir, child_dir, img_name)
img = Image.open(img_item_path)
在掌握了如何組裝路徑、獲取路徑中的文件名以及獲取具體圖像對象后,我們可以完善我們的__init__
與__getitem__
方法了
完善__init__方法
在init中為啥使用self:一個函數中的變量是不能拿到另外一個函數中使用的,self可以當做類中的全局變量
class GetData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path_list = os.listdir(self.path)
很簡單,就是接收實例化時傳入的參數:獲取根目錄路徑、子目錄路徑
然后將兩個路徑進行組合,就得到了目標數據集的路徑
我們將這個路徑作為參數傳入listdir函數,從而讓img_path_list中存儲該目錄下所有文件名(包含后綴名)
此時通過索引就可以輕松獲取每個文件名
接下來,我們要使用這些初始化的信息去獲取其中的每一個圖片的JpegImageFile對象
完善__getitem__方法
我們在初始化中,已經通過組裝數據集路徑,進而通過listdir方法獲取了數據集中每個文件的文件名,存入了一個列表中。
在__getitem__方法中,默認會有一個 item 參數,常命名為 idx,這個參數是一個索引編號,用於對我們初始化中得到的文件名列表進行索引訪問,我們就得到了具體的文件名,然后與根目錄、子目錄再次組裝,得到具體數據的相對路徑,我們可以通過這個路徑獲取到索引編號對應的數據對象本身。
這樣巧妙的讓索引與數據集中的具體數據對應了起來
def __getitem__(self, idx):
img_name = self.img_path_list[idx] # 從文件名列表中獲取了文件名
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 組裝路徑,獲得了圖片具體的路徑
獲取了具體的圖像路徑后,我們需要使用PIL讀取這個圖像
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
此處img是一個JpegImageFile對象,label是一個字符串
自此,這個函數我們就實現完成了
以后使用這個類進行實例化時,傳入的參數是根目錄路徑,以及對應的label名,我們就可以得到一個GetData對象。
有了這個GetData對象后,我們可以直接使用索引來獲取具體的圖像對象(類:JpegImageFile),因為__getitem__方法已經幫我們實現了,我們只需要使用索引即可調用__getitem__方法,會返回我們根據索引提取到的對應數據的圖像對象以及其label
root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = GetData(root_dir, ants_label_dir)
bees_dataset = GetData(root_dir, bees_label_dir)
img1, label1 = ants_dataset[0] # 返回一個元組,返回值是__getitem__方法的返回值
img2, label2 = bees_dataset[0]
完善__len__方法
__len__實現很簡單
主要功能是獲取數據集的長度,由於我們在初始化中已經獲取了所有文件名的列表,所以只需要知道這個列表的長度,就知道了有多少個文件,也就是知道了有多少個具體的數據
def __len__(self):
return len(self.img_path_list)
組合數據集
我們還可以將兩個數據集對象進行組合,組合成一個大的數據集對象
train_dataset = ants_dataset + bees_dataset
我們看看這三個數據集對象的大小(在python Console中):
len1 = len(ants_dataset)
len2 = len(bees_dataset)
len3 = len(train_dataset)
輸出:
124
121
245
我們可以看到剛好 $$124 + 121 = 245$$
而對這個組合的數據集的訪問也很有意思,也同樣是使用索引,0 ~ 123 都是ants數據集的內容,124 - 244 都是bees數據集的內容
img1, label1 = train_dataset[123]
img1.show()
img2, label2 = train_dataset[124]
img2.show()
完整代碼
from torch.utils.data import Dataset
from PIL import Image
import os
class GetData(Dataset):
# 初始化為整個class提供全局變量,為后續方法提供一些量
def __init__(self, root_dir, label_dir):
# self
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path_list = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path_list[idx] # 只獲取了文件名
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 每個圖片的位置
# 讀取圖片
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = GetData(root_dir, ants_label_dir)
bees_dataset = GeyData(root_dir, bees_label_dir)
img, lable = ants_dataset[0] # 返回一個元組,返回值就是__getitem__的返回值
# 獲取整個訓練集,就是對兩個數據集進行了拼接
train_dataset = ants_dataset + bees_dataset
len1 = len(ants_dataset) # 124
len2 = len(bees_dataset) # 121
len = len(train_dataset) # 245
img1, label1 = train_dataset[123] # 獲取的是螞蟻的最后一個
img2, label2 = train_dataset[124] # 獲取的是蜜蜂第一個