torch.utils.data.DataLoader與迭代器轉換
在做實驗時,我們常常會使用用開源的數據集進行測試。而Pytorch中內置了許多數據集,這些數據集我們常常使用DataLoader
類進行加載。
如下面這個我們使用DataLoader
類加載torch.vision
中的FashionMNIST
數據集。
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.CIFAR10(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.CIFAR10(
root="data",
train=False,
download=True,
transform=ToTensor()
)
我們接下來定義Dataloader
對象用於加載這兩個數據集:
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
那么這個train_dataloader
究竟是什么類型呢?
print(type(train_dataloader)) # <class 'torch.utils.data.dataloader.DataLoader'>
我們可以將先其轉換為迭代器類型。
print(type(iter(train_dataloader)))# <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>
然后再使用next(iter(train_dataloader))
從迭代器里取數據,如下所示:
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0]
label = train_labels[0]
plt.imshow(torch.permute(img, (1, 2, 0)))
plt.show()
print(f"Label: {label}")
可以看到我們成功獲取了數據集中第一張圖片的信息,控制台打印:
Feature batch shape: torch.Size([64, 3, 32, 32])
Labels batch shape: torch.Size([64])
Label: 1
圖片可視化顯示如下:
PS: 事實上我們也可以直接索引
datasets
對象以訪問第1個樣本的圖片數據和標簽數據:
print(training_data[0][0].shape) # torch.Size([3, 32, 32])
print(training_data[0][1]) # 6
這里
training_data[0]
是第一個樣本對應的(圖片, 標簽)
元組。training_data[0][0]
是第一個樣本的圖片數據,training_data[0][1]
則是第一個樣本的標簽數據。如上所示,通過直接索引得到的單張圖片通道也在第一維。這種順序也就是經典的NCHW(N、C、H、W分別為批量、通道、圖片高度、圖片寬度維度)。Caffe的通道順序也是NCHW,但Tensorlfow1.*和Tensorlfow2.*的順序都為HWC順序嗎?那是什么時候變成的CHW順序呢?原來,這是因為我們在最開始時設置了transform=ToTensor()
,而ToTensor()
函數會在獲取圖片時將其順序變為CHW(關於ToTensor()
函數更多的作用,可參見我的博客《Pytorch:以單通道(灰度圖)加載圖片》)。此外,我們也訪問
datasets
對象的data
和target
屬性來分別獲得第1個樣本的圖片數據和標簽數據,如下列命令所示:
print(training_data.data[0].shape) # (32, 32, 3)
print(training_data.targets[0]) # 6
注意,此時
datasets
對象的data
屬性獲得的單張圖片中通道在最后一維! 如果再將這樣的圖片數據用Dataloader
對象進行迭代,則最后獲得的單張圖片中通道仍會保持在最后一維:
training_data = training_data.data
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
train_features = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}") # torch.Size([64, 32, 32, 3])
誒,我們不是設置了
ToTensor()
函數嗎?原來,ToTensor
函數是在對datasets
對象調用__getitem__
方法時觸發調用的(這里使用了惰性求值(lazy evaluation)的思想),__getitem__
方法大致如下所示:
def __getitem__(self, idx):
x, y = self.dataset[idx]
if self.transform:
x = self.transform(x)
return x, y
而我們使用
training_data.data[0]
獲取數據,是在對training_data.data
進行索引,而沒有對datasets
對象本身進行索引的操作,就不會去調用datasets
對象的__getitem__
方法,自然就不會進行圖片維度順序的轉換了。最后,這里再說過個題外話,這里Pytorch默認的三個通道像素順序為RGB,事實上PIL庫、Tensorflow1.*/Tensorflow2.*和我們日常圖片存儲的通道像素順序都是RGB,但並非所有軟件都是如此。例如OpenCV的通道像素順序就為BGR。
上面提到的這些點在做實驗時都需要額外注意。
接下來我們言歸正傳,接着來看DataLoader
迭代器。有讀者可能就會產生疑問,很多時候我們並沒有將DataLoader
類型強制轉換成迭代器類型呀,大多數時候我們會寫如下代碼:
for train_features, train_labels in train_dataloader:
print(train_features.shape) # torch.Size([64, 3, 32, 32])
print(train_features[0].shape) # torch.Size([3, 32, 32])
img = train_features[0]
label = train_labels[0]
plt.imshow(torch.permute(img, (1, 2, 0)))
plt.show()
print(f"Label: {label}")
可以看到,該代碼也能夠正常迭代訓練數據,前三個樣本的控制台打印輸出為:
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 6
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 7
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 9
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
那么為什么我們這里沒有顯式將Dataloader
轉換為迭代器類型呢,其實是Python語言for循環的一種機制,一旦我們用for ... in ...
句式來迭代一個對象,那么Python解釋器就會偷偷地自動幫我們創建好迭代器,也就是說
for train_features, train_labels in train_dataloader:
實際上等同於
for train_features, train_labels in iter(train_dataloader):
更進一步,這實際上等同於
train_iterator = iter(train_dataloader)
try:
while True:
train_features, train_labels = next(train_iterator)
except StopIteration:
pass
推而廣之,我們在用Python迭代直接迭代列表時:
for x in [1, 2, 3, 4]:
其實Python解釋器已經為我們隱式轉換為迭代器了:
list_iterator = iter([1, 2, 3, 4])
try:
while True:
x = next(list_iterator)
except StopIteration:
pass
參考
- [1] https://pytorch.org/
- [2] Martelli A, Ravenscroft A, Ascher D. Python cookbook[M]. " O'Reilly Media, Inc.", 2005.