pytorch 的數據加載到模型的操作順序如下:
- 創建一個 Dataset 對象
- 創建一個 DataLoader對象
- 循環這個 DataLoader對象,將data, label加載到模型中進行訓練
torch.utils.data.Dataset
pytorch中文文檔中的torch.utils.data
表示Dataset的抽象類。所有其他數據集都應該進行子類化。所有子類應該override__len__和__getitem__,前者提供了數據集的大小,后者支持整數索引,范圍從0到len(self)。
源碼參考點擊
class Dataset(object):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
torch.utils.data.DataLoader
參考:【pytorch】torch.utils.data.DataLoader
源碼參考點擊
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
參考今夜無風的博客 和 pytorch之dataloader深入剖析
數據加載器,結合了數據集和取樣器,並且可以提供多個線程處理數據集。在訓練模型時使用到此函數,用來把訓練數據分成多個小組,此函數每次拋出一組數據。直至把所有的數據都拋出。就是做一個數據的初始化。
DataLoader本質是一個可迭代對象,使用iter()訪問,不能使用next()訪問;使用iter(DataLoader)返回的是一個迭代器,然后可以使用next訪問。
DataLoader本質上就是一個iterable(跟python的內置類型list等一樣),並利用多進程來加速batch data的處理,使用yield來使用有限的內存。
語音自定義數據集
參考
- 博客園文章:pytorch加載語音類自定義數據集
