實際上pytorch在定義dataloader的時候是需要傳入很多參數的,比如,number_workers, pin_memory, 以及shuffle, dataset等,其中sampler參數算是其一
sampler實際上定義了torch.utils.data.dataloader的數據取樣方式,什么意思呢?
在自己定義dataset中的__getitem__函數的時候,每一個index,唯一的對應一個樣本,sampler實際上就是一系列的index組成的可迭代對象
如下圖所示的__iter__函數返回的可迭代對象
下圖所示的是randomsampler,即隨機的shuffle圖片的index,然后取樣,關鍵就一句話,在__iter__中的torch.randperm(n).tolist()
表明產生了一個0到n-1的一個list
比如我的數據是128張圖片,然后,dataset中的__len__也是128,我的sampler,如果不shuffle的話,其中的index從0-127沒毛病
如果將這個順序打亂,那就是相當於隨機取樣,和上圖一樣
也就是說,每個圖片都定義了唯一的一個index,取圖的時候按照sampler定義的規則來取圖,實際上這樣就可以做一些有意思的事情了
比如我的batch size是2,我想每一個batch取的圖片是第一張和緊挨着的后面的一張圖,假設sampler不shuffle的話,那么__iter__返回的可迭代對象應該是
iter([0,1,1,2,2,3,3,4,4,5,.....])沒毛病
再比如,我想我的batch是4,隔一張取一下圖片,那么我的sampler的函數返回的__iter__應該是iter[0,2,4,6,1,3,5,7,2,4,6,8,.......]
但是這又有什么用?比如按照我上面的第一種取圖的需求,我完全可以在__getitem__中定義下一個index使得和上面一個讀到的圖像一樣。用處就在這里,
假設是定義一個相同的index,讀取相同的圖片是非常的占用內存的,比如imagenet,讀完放到內存里面,大概是需要一百多個g,按照我剛剛的第一個例子讀取,內存就需要2倍,實際上這是不允許的,通過定義sampler,對於一張圖片重復采樣,比對於一張圖片讀取多遍顯然要划算的多