DataLoader
DataLoader 是 PyTorch 中讀取數據的一個重要接口,該接口定義在 dataloader.py 文件中,該接口的目的: 將自定義的 Dataset 根據 batch size 的大小、是否 shuffle 等封裝成一個 batch size 大小的 Tensor,用於后面的訓練。
通過 DataLoader,使得我們在准備 mini-batch 時可以多線程並行處理,這樣可以加快准備數據的速度。
DataLoader 是一個高效、簡潔、直觀地網絡輸入數據結構,便於使用和擴展
- DataLoader 本質是一個可迭代對象,使用 iter() 訪問,不能使用 next() 訪問
- 使用 iter(dataloader) 返回的是一個迭代器,然后使用 next() 訪問
- 也可以使用
for features, targets in dataloaders
進行可迭代對象的訪問- 一般我們實現一個 datasets 對象,傳入到 DataLoader 中,然后內部使用
yield
返回每一次 batch 的數據
DataLoader(object) 的部分參數:
# 傳入的數據集
dataset(Dataset)
# 每個 batch 有多少個樣本
batch_size(int, optional)
# 在每個 epoch 開始的時候,對數據進行重新排序
shuffle(bool, optional)
# 自定義從數據集中抽取樣本的策略,如果指定這個參數,那么 shuffle 必須為 False
sampler(Sampler, optional)
# 與 sampler 類似,但是一次只返回一個 batch 的 indices(索引),如果指定這個參數,那么 batch_size, shuffle, sampler, drop_last 就不能再指定了
batch_sampler(Sampler, optional)
# 這個參數決定有多少進程處理數據加載,0 意味着所有數據都會被加載到主進程,默認為0
num_workers(int, optional)
# 如果設置為 True,則最后不足batch_size大小的數據會被丟棄,比如batch_size=64, 而一個epoch只有100個樣本,則最后36個會被丟棄;如果設置為False,則最后的batch_size會小一點
drop_last(bool, optional)
Reference: