Dataset 是 PyTorch 中用來表示數據集的一個抽象類,我們的數據集可以用這個類來表示,至少需要覆寫下面兩個方法:
1)__len__:一般用來返回數據集大小。
2)__getitem__:實現這個方法后,可以通過下標的方式 dataset[i] 的來取得第 $i$ 個數據。
DataLoader 本質上就是一個 iterable(內部定義了 __iter__ 方法),__iter__ 被定義成生成器,使用 yield 來返回數據,
並利用多進程來加速 batch data 的處理,DataLoader 組裝好數據后返回的是 Tensor 類型的數據。
注意:DataLoader 是間接通過 Dataset 來獲得數據的,然后進行組裝成一個 batch 返回,因為采用了生成器,所以每次只會組裝
一個 batch 返回,不會一次性組裝好全部的 batch,所以 DataLoader 節省的是 batch 的內存,並不是指數據集的內存,數據集可
以一開始就全部加載到內存里,也可以分批加載,這取決於 Dataset 中 __init__ 函數的實現。
舉個例子:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class DiabetesDataset(Dataset):
def __init__(self, filepath):
# 因為數據集比較小,所以全部加載到內存里了
data = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
self.len = data.shape[0]
self.x_data = torch.from_numpy(data[:,:-1])
self.y_data = torch.from_numpy(data[:,[-1]])
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, # 傳遞數據集
batch_size=32, # 小批量的數據大小,每次加載一batch數據
shuffle=True, # 打亂數據之間的順序
num_workers=2) # 使用多少個子進程來加載數據,默認為0, 代表使用主線程加載batch數據
for epoch in range(100): # 訓練 100 輪
for i, data in enumerate(train_loader, 0): # 每次惰性返回一個 batch 數據
iuputs, label = data
...
