Dataset 和 DataLoader 詳解


Dataset 是 PyTorch 中用來表示數據集的一個抽象類,我們的數據集可以用這個類來表示,至少需要覆寫下面兩個方法:

    1)__len__:一般用來返回數據集大小。

    2)__getitem__:實現這個方法后,可以通過下標的方式 dataset[i] 的來取得第 $i$ 個數據。

DataLoader 本質上就是一個 iterable(內部定義了 __iter__ 方法),__iter__ 被定義成生成器,使用 yield 來返回數據,

並利用多進程來加速 batch data 的處理,DataLoader 組裝好數據后返回的是 Tensor 類型的數據。

注意:DataLoader 是間接通過 Dataset 來獲得數據的,然后進行組裝成一個 batch 返回,因為采用了生成器,所以每次只會組裝

一個 batch 返回,不會一次性組裝好全部的 batch,所以 DataLoader 節省的是 batch 的內存,並不是指數據集的內存,數據集可

以一開始就全部加載到內存里,也可以分批加載,這取決於 Dataset 中 __init__ 函數的實現。

舉個例子:

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # 因為數據集比較小,所以全部加載到內存里了
        data = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = data.shape[0]
        self.x_data = torch.from_numpy(data[:,:-1])
        self.y_data = torch.from_numpy(data[:,[-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset,   # 傳遞數據集
                          batch_size=32,     # 小批量的數據大小,每次加載一batch數據
                          shuffle=True,      # 打亂數據之間的順序
                          num_workers=2)     # 使用多少個子進程來加載數據,默認為0, 代表使用主線程加載batch數據

for epoch in range(100):  # 訓練 100 輪
    for i, data in enumerate(train_loader, 0):  # 每次惰性返回一個 batch 數據
        iuputs, label = data
        ...

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM