Pytorch:torch.utils.data.DataLoader與迭代器轉換


torch.utils.data.DataLoader與迭代器轉換

在做實驗時,我們常常會使用用開源的數據集進行測試。而Pytorch中內置了許多數據集,這些數據集我們常常使用DataLoader類進行加載。
如下面這個我們使用DataLoader類加載torch.vision中的FashionMNIST數據集。

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

我們接下來定義Dataloader對象用於加載這兩個數據集:

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

那么這個train_dataloader究竟是什么類型呢?

print(type(train_dataloader))  # <class 'torch.utils.data.dataloader.DataLoader'>

我們可以將先其轉換為迭代器類型。

print(type(iter(train_dataloader)))# <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>

然后再使用next(iter(train_dataloader))從迭代器里取數據,如下所示:

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0]
label = train_labels[0]
plt.imshow(torch.permute(img, (1, 2, 0)))
plt.show()
print(f"Label: {label}")

可以看到我們成功獲取了數據集中第一張圖片的信息,控制台打印:

Feature batch shape: torch.Size([64, 3, 32, 32])
Labels batch shape: torch.Size([64])
Label: 1

圖片可視化顯示如下:
NLP多任務學習

PS: 事實上我們也可以直接索引datasets對象以訪問第1個樣本的圖片數據和標簽數據:

print(training_data[0][0].shape) # torch.Size([3, 32, 32])
print(training_data[0][1]) # 6

這里training_data[0]是第一個樣本對應的(圖片, 標簽)元組。training_data[0][0]是第一個樣本的圖片數據,training_data[0][1]則是第一個樣本的標簽數據。如上所示,通過直接索引得到的單張圖片通道也在第一維。這種順序也就是經典的NCHW(N、C、H、W分別為批量、通道、圖片高度、圖片寬度維度)。Caffe的通道順序也是NCHW,但Tensorlfow1.*和Tensorlfow2.*的順序都為HWC順序嗎?那是什么時候變成的CHW順序呢?原來,這是因為我們在最開始時設置了transform=ToTensor(),而 ToTensor()函數會在獲取圖片時將其順序變為CHW(關於ToTensor()函數更多的作用,可參見我的博客《Pytorch:以單通道(灰度圖)加載圖片》)。

此外,我們也訪問datasets對象的datatarget屬性來分別獲得第1個樣本的圖片數據和標簽數據,如下列命令所示:

print(training_data.data[0].shape) # (32, 32, 3)
print(training_data.targets[0]) # 6

注意,此時datasets對象的data屬性獲得的單張圖片中通道在最后一維! 如果再將這樣的圖片數據用Dataloader對象進行迭代,則最后獲得的單張圖片中通道仍會保持在最后一維:


training_data = training_data.data

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

train_features = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}") # torch.Size([64, 32, 32, 3])

誒,我們不是設置了ToTensor()函數嗎?原來,ToTensor函數是在對datasets對象調用__getitem__方法時觸發調用的(這里使用了惰性求值(lazy evaluation)的思想),__getitem__方法大致如下所示:

def __getitem__(self, idx):

    x, y = self.dataset[idx]
    
    if self.transform:
        x = self.transform(x)
    
    return x, y   

而我們使用training_data.data[0]獲取數據,是在對training_data.data進行索引,而沒有對datasets對象本身進行索引的操作,就不會去調用datasets對象的__getitem__方法,自然就不會進行圖片維度順序的轉換了。

最后,這里再說過個題外話,這里Pytorch默認的三個通道像素順序為RGB,事實上PIL庫、Tensorflow1.*/Tensorflow2.*和我們日常圖片存儲的通道像素順序都是RGB,但並非所有軟件都是如此。例如OpenCV的通道像素順序就為BGR。

上面提到的這些點在做實驗時都需要額外注意。

接下來我們言歸正傳,接着來看DataLoader迭代器。有讀者可能就會產生疑問,很多時候我們並沒有將DataLoader類型強制轉換成迭代器類型呀,大多數時候我們會寫如下代碼:

for train_features, train_labels in train_dataloader: 
    print(train_features.shape) # torch.Size([64, 3, 32, 32])
    print(train_features[0].shape) # torch.Size([3, 32, 32])
    
    img = train_features[0]
    label = train_labels[0]
    plt.imshow(torch.permute(img, (1, 2, 0)))
    plt.show()
    print(f"Label: {label}")

可以看到,該代碼也能夠正常迭代訓練數據,前三個樣本的控制台打印輸出為:

torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 6
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 7
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 9
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])

那么為什么我們這里沒有顯式將Dataloader轉換為迭代器類型呢,其實是Python語言for循環的一種機制,一旦我們用for ... in ...句式來迭代一個對象,那么Python解釋器就會偷偷地自動幫我們創建好迭代器,也就是說

for train_features, train_labels in train_dataloader:

實際上等同於

for train_features, train_labels in iter(train_dataloader):

更進一步,這實際上等同於

train_iterator = iter(train_dataloader)
try:
    while True:
        train_features, train_labels = next(train_iterator)
except StopIteration:
    pass

推而廣之,我們在用Python迭代直接迭代列表時:

for x in [1, 2, 3, 4]:

其實Python解釋器已經為我們隱式轉換為迭代器了:

list_iterator = iter([1, 2, 3, 4])
try:
    while True:
        x = next(list_iterator)
except StopIteration:
    pass

參考

  • [1] https://pytorch.org/
  • [2] Martelli A, Ravenscroft A, Ascher D. Python cookbook[M]. " O'Reilly Media, Inc.", 2005.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM