np.transpose(np_image, [1, 2, 0])
pytorch中讀入圖片並進行顯示時
# visualization of an example of training data def show_image(tensor_image): np_image = tensor_image.numpy() np_image = np.transpose(np_image, [1, 2, 0])*0.5 + 0.5 # 轉置后做逆歸一化 plt.imshow(np_image)
plt.show() X = iter(train_loader).next()[0] print(X.size()) show_image(X)
其中有一行命令用來轉置
np.transpose(np_image, [1, 2, 0])
主要是Pytorch中使用的數據格式與plt.imshow()函數的格式不一致
Pytorch中為[Channels, H, W]
而plt.imshow()中則是[H, W, Channels]
因此,要先轉置一下。
該函數的解釋見:plt.imshow()
pytorch讀入並顯示圖片的方法
方式一
將讀取出來的torch.FloatTensor轉換為numpy
np_image = tensor_image.numpy()
np_image = np.transpose(np_image, [1, 2, 0])
plt.show()
方式二
利用torchvision中的功能函數,一般用於批量顯示圖片。
img=torchvision.utils.make_grid(img).numpy() plt.imshow(np.transpose(img,(1,2,0))) plt.show()