Dataset 是 PyTorch 中用來表示數據集的一個抽象類,我們的數據集可以用這個類來表示,至少需要覆寫下面兩個方法:
1)__len__:一般用來返回數據集大小。
2)__getitem__:實現這個方法后,可以通過下標的方式 dataset[i] 的來取得第 $i$ 個數據。
DataLoader 本質上就是一個 iterable(內部定義了 __iter__ 方法),__iter__ 被定義成生成器,使用 yield 來返回數據,
並利用多進程來加速 batch data 的處理,DataLoader 組裝好數據后返回的是 Tensor 類型的數據。
注意:DataLoader 是間接通過 Dataset 來獲得數據的,然后進行組裝成一個 batch 返回,因為采用了生成器,所以每次只會組裝
一個 batch 返回,不會一次性組裝好全部的 batch,所以 DataLoader 節省的是 batch 的內存,並不是指數據集的內存,數據集可
以一開始就全部加載到內存里,也可以分批加載,這取決於 Dataset 中 __init__ 函數的實現。
舉個例子:
import torch import numpy as np from torch.utils.data import Dataset from torch.utils.data import DataLoader class DiabetesDataset(Dataset): def __init__(self, filepath): # 因為數據集比較小,所以全部加載到內存里了 data = np.loadtxt(filepath, delimiter=',', dtype=np.float32) self.len = data.shape[0] self.x_data = torch.from_numpy(data[:,:-1]) self.y_data = torch.from_numpy(data[:,[-1]]) def __getitem__(self, index): return self.x_data[index], self.y_data[index] def __len__(self): return self.len dataset = DiabetesDataset('diabetes.csv.gz') train_loader = DataLoader(dataset=dataset, # 傳遞數據集 batch_size=32, # 小批量的數據大小,每次加載一batch數據 shuffle=True, # 打亂數據之間的順序 num_workers=2) # 使用多少個子進程來加載數據,默認為0, 代表使用主線程加載batch數據 for epoch in range(100): # 訓練 100 輪 for i, data in enumerate(train_loader, 0): # 每次惰性返回一個 batch 數據 iuputs, label = data ...