寫到前面
這是torchvision.utils模塊里面的兩個方法,因為比較常用,所以pytorch直接封裝好了。
制作網格
網絡圖像一般用於訓練數據或測試數據的可視化。
torchvision.utils.make_grid(tensor, nrow, padding) → torch.Tensor
- 描述
將多張tensor格式的圖像以網格的方式封裝到一起。
- 參數
tensor (tensor or list):四維 (B x C x H x W) mini-batch的tensor數據或者是包含同一尺寸的圖片列表。
nrow (int):網格每行圖片的個數,默認是8;千萬不要理解為圖片的行數。
padding (int):四周填充的寬度,默認是2,你可以理解為網格中圖片之間的間距。默認填充值是0,也就是黑色。
注:這是三個比較常用的參數,其它參數請參考官方文檔。
- 示例
# 以mnist數據集為例,train_loader的batch_size設置為9
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
images = torchvision.utils.make_grid(images, 3, 0)
print(images.size()) # torch.Size([3, 84, 84])
- 繪圖

保存本地
tensor數據類型保存時不用再轉為PIL.Image或numpy.ndarray,pytorch直接給我們寫好了一個方法。
torchvision.utils.save_image(tensor, fp) → None
- 描述
直接將tensor數據保存為圖像。
- 參數
tensor (Tensor or list):待保存的tensor數據。如果給以一個四維的mini-batch的tensor,將調用網格方法,然后再保存到本地。
fp (string or file object)):圖像的保存路徑。
注:這是兩個比較常用的參數,其它參數請參考官方文檔。
- 示例
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
images = torchvision.utils.make_grid(images, 3, 0)
print(images.size()) # torch.Size([3, 84, 84])
torchvision.utils.save_image(images, 'test.jpg')
完整代碼
#%% 導入模塊
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
#%% 下載數據集
train_file = datasets.MNIST(
root='./dataset/',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True
)
#%% 制作數據加載器
train_loader = DataLoader(
dataset=train_file,
batch_size=9,
shuffle=True
)
#%% 訓練數據可視化
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
images = make_grid(images, 3, 0)
print(images.size()) # torch.Size([3, 84, 84])
save_image(images, 'test.jpg')
