PyTorch DataLoader()使用


DataLoader的作用:通常在訓練時我們會將數據集分成若干小的、隨機的batch,這個操作當然可以手動操作,但是PyTorch里面為我們提供了API讓我們方便地從dataset中獲得batch,DataLoader就是干這事兒的。
先看官方文檔的描述,包括了每個參數的定義:

它的本質是一個可迭代對象,一般的操作是:

  1. 創建一個dataset對象
  2. 創建一個DataLoader對象
  3. 遍歷這個DataLoader對象,將data, label加載到模型中進行訓練
#一個粗略的示意
dataset = torchvision.datasets.MNIST()  #從torchvision這個包里獲得一個dataset對象
train_iter = torch.utils.data.DataLoader(dataset, batch_size = args.batch_size, shuffle = True)#創建DataLoader對象
for epoch in num(epochs):#將數據加載到模型之中
    for data, label in train_iter:
        ...

DataLoader還有更多的細節,但現在還沒有遇到,所以先記下這部分。
這個博客關於這個話題講得不錯,參考 https://www.cnblogs.com/ranjiewen/p/10128046.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM