pytorch中如何使用DataLoader對數據集進行批處理


最近搞了搞minist手寫數據集的神經網絡搭建,一個數據集里面很多個數據,不能一次喂入,所以需要分成一小塊一小塊喂入搭建好的網絡。

pytorch中有很方便的dataloader函數來方便我們進行批處理,做了簡單的例子,過程很簡單,就像把大象裝進冰箱里一共需要幾步?

 


 

第一步:打開冰箱門。

我們要創建torch能夠識別的數據集類型(pytorch中也有很多現成的數據集類型,以后再說)。

首先我們建立兩個向量X和Y,一個作為輸入的數據,一個作為正確的結果:

    

隨后我們需要把X和Y組成一個完整的數據集,並轉化為pytorch能識別的數據集類型:

    

我們來看一下這些數據的數據類型:

     

可以看出我們把X和Y通過Data.TensorDataset() 這個函數拼裝成了一個數據集,數據集的類型是【TensorDataset】。

好了,第一步結束了,冰箱門打開了。

 


 

第二步:把大象裝進去。

就是把上一步做成的數據集放入Data.DataLoader中,可以生成一個迭代器,從而我們可以方便的進行批處理。

     

DataLoader中也有很多其他參數:

dataset:Dataset類型,從其中加載數據 
batch_size:int,可選。每個batch加載多少樣本 
shuffle:bool,可選。為True時表示每個epoch都對數據進行洗牌 
sampler:Sampler,可選。從數據集中采樣樣本的方法。 
num_workers:int,可選。加載數據時使用多少子進程。默認值為0,表示在主進程中加載數據。 
collate_fn:callable,可選。 
pin_memory:bool,可選 
drop_last:bool,可選。True表示如果最后剩下不完全的batch,丟棄。False表示不丟棄。

好了,第二步結束了,大象裝進去了。

 


 

第三步:把冰箱門關上。

好啦,現在我們就可以愉快的用我們上面定義好的迭代器進行訓練啦。

在這里我們利用print來模擬我們的訓練過程,即我們在這里對搭建好的網絡進行喂入。

     

輸出的結果是:

      

可以看到,我們一共訓練了所有的數據訓練了5次。數據中一共10組,我們設置的mini-batch是3,即每一次我們訓練網絡的時候喂入3組數據,到了最后一次我們只有1組數據了,比mini-batch小,我們就僅輸出這一個。

此外,還可以利用python中的enumerate(),是對所有可以迭代的數據類型(含有很多東西的list等等)進行取操作的函數,用法如下:

     

 

好啦,現在冰箱門就關上啦,(*^__^*) 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM