項目場景
pytorch
訓練時我們一般把數據集放到數據加載器里,然后分批拿出來訓練。訓練前我們一般還要看一下訓練數據長啥樣,也就是訓練數據集可視化。那么如何顯示dataloader
里面帶batch
的tensor
類型的圖像呢?
顯示圖像
繪圖最常用的庫就是matplotlib
:
pip install matplotlib
顯示圖像會用到matplotlib.pyplot.imshow
方法。查閱官方文檔可知,該方法接收的圖像的通道數要放到后面:
數據加載器中數據的維度是[B, C, H, W]
,我們每次只拿一個數據出來就是[C, H, W]
,而matplotlib.pyplot.imshow
要求的輸入維度是[H, W, C]
,所以我們需要交換一下數據維度,把通道數放到最后面,這里用到pytorch
里面的permute
方法(transpose
方法也行,不過要交換兩次,沒這個方便,numpy
中的transpose
方法倒是可以一次交換完成),用法示例如下:
>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])
代碼示例
#%% 導入模塊
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下載數據集
train_file = datasets.MNIST(
root='./dataset/',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True
)
#%% 制作數據加載器
train_loader = DataLoader(
dataset=train_file,
batch_size=9,
shuffle=True
)
#%% 訓練數據可視化
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
plt.subplot(3, 3, i+1)
plt.title(labels[i].item())
plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
plt.axis('off')
plt.show()
這里以mnist
數據集為例,演示一下顯示效果。我這個代碼其實還有一點小問題。數據增強的時候我不是進行標准化了嘛,就是在第7行代碼:Normalize((0.1307,), (0.3081,))
。所以,如果你想查看訓練集的原始圖像,還得反標准化。
- 標准化:
image = (image-mean)/std
- 反標准化:
image = image*std+mean
我拿imagenet
中的一個螞蟻和蜜蜂的子集做了一下實驗,標准化前后的區別還是很明顯的:
最終效果
引用參考
https://pytorch.org/docs/stable/tensors.html
https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html