【pytorch】帶batch的tensor類型圖像顯示


項目場景

pytorch訓練時我們一般把數據集放到數據加載器里,然后分批拿出來訓練。訓練前我們一般還要看一下訓練數據長啥樣,也就是訓練數據集可視化。那么如何顯示dataloader里面帶batchtensor類型的圖像呢?

顯示圖像

繪圖最常用的庫就是matplotlib

pip install matplotlib

顯示圖像會用到matplotlib.pyplot.imshow方法。查閱官方文檔可知,該方法接收的圖像的通道數要放到后面:
在這里插入圖片描述
數據加載器中數據的維度是[B, C, H, W],我們每次只拿一個數據出來就是[C, H, W],而matplotlib.pyplot.imshow要求的輸入維度是[H, W, C],所以我們需要交換一下數據維度,把通道數放到最后面,這里用到pytorch里面的permute方法(transpose方法也行,不過要交換兩次,沒這個方便,numpy中的transpose方法倒是可以一次交換完成),用法示例如下:

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代碼示例

#%% 導入模塊
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下載數據集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作數據加載器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 訓練數據可視化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

這里以mnist數據集為例,演示一下顯示效果。我這個代碼其實還有一點小問題。數據增強的時候我不是進行標准化了嘛,就是在第7行代碼:Normalize((0.1307,), (0.3081,))。所以,如果你想查看訓練集的原始圖像,還得反標准化。

  • 標准化:image = (image-mean)/std
  • 反標准化:image = image*std+mean

我拿imagenet中的一個螞蟻和蜜蜂的子集做了一下實驗,標准化前后的區別還是很明顯的:
在這里插入圖片描述

最終效果

在這里插入圖片描述

引用參考

https://pytorch.org/docs/stable/tensors.html
https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM