DataLoader
的作用:通常在訓練時我們會將數據集分成若干小的、隨機的batch,這個操作當然可以手動操作,但是PyTorch里面為我們提供了API讓我們方便地從dataset中獲得batch,DataLoader
就是干這事兒的。
先看官方文檔的描述,包括了每個參數的定義:
它的本質是一個可迭代對象,一般的操作是:
- 創建一個
dataset
對象 - 創建一個
DataLoader
對象 - 遍歷這個
DataLoader
對象,將data
,label
加載到模型中進行訓練
#一個粗略的示意
dataset = torchvision.datasets.MNIST() #從torchvision這個包里獲得一個dataset對象
train_iter = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, shuffle = True)#創建DataLoader對象
for epoch in num(epochs):#將數據加載到模型之中
for data, label in train_iter:
...
DataLoader
還有更多的細節,但現在還沒有遇到,所以先記下這部分。
這個博客關於這個話題講得不錯,參考 https://www.cnblogs.com/ranjiewen/p/10128046.html