Pytorch:以單通道(灰度圖)加載圖片


以單通道(灰度圖)加載圖片

如果我們想以單通道加載圖片,設置加載數據集時的transform參數如下即可:

from torchvision import datasets, transforms
transform = transforms.Compose(
    [

        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor()
    ]
) 
data = datasets.CIFAR10(root=".", download=True,transform=transform)
print(type(data[0][0])) # <class 'torch.Tensor'>
print(data[0][0].shape) # torch.Size([1, 32, 32])
print(data[0][0])
# tensor([[[0.2392, 0.1765, 0.1882,  ..., 0.5373, 0.5098, 0.5059],
#          [0.0745, 0.0000, 0.0392,  ..., 0.3725, 0.3529, 0.3686],
#          [0.0941, 0.0353, 0.1216,  ..., 0.3529, 0.3569, 0.3137],
#          ...,
#          [0.6784, 0.6039, 0.6157,  ..., 0.5255, 0.1412, 0.1490],
#          [0.5725, 0.5059, 0.5647,  ..., 0.6000, 0.2706, 0.2353],
#          [0.5922, 0.5373, 0.5765,  ..., 0.7412, 0.4863, 0.3882]]])

可以看到我們得到了歸一化后的單通道torch.Tensor對象。

PS:torch.Tensor對象可以以torch.tensor(...)torch.Tensor(...)兩種方法初始化得到的,具體區別在於torch.Tensor(...)可接受多個參數,其參數表示Tensor各個維度的大小,比如torch.Tensor會返回一個為已初始化的存有10個數(類型為torch.float32)的Tensor對象,而torch.tensor(10)只能接受一個參數,該參數表示初始化的數據,比如torch.tensor(10)會返回一個包含單個值10(類型為torch.int64)的Tensor對象:

import torch

a = torch.Tensor(10)
print(a)
# tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 3.6734e-40, 0.0000e+00, 2.0000e+00,
#         0.0000e+00, 2.0000e+00, 7.3787e+2])
print(type(a)) # <class 'torch.Tensor'>
print(a.dtype) # torch.float32

b = torch.tensor(10)
print(b) # tensor(10)
print(type(b)) # <class 'torch.Tensor'>
print(b.dtype) # torch.int64

a = torch.Tensor(2, 3)
print(a)
# tensor([[1.6217e-19, 7.0062e+22, 6.3828e+28],
#         [3.8016e-39, 0.0000e+00, 2.0000e+00]])

b = torch.tensor([2, 3])
print(b) # tensor([2, 3])
b = torch.tensor((2, 3))
print(b) # tensor([2, 3])

詳情可參見Pytorch討論區帖子:Difference between torch.tensor() and torch.Tensor()[1]

這里再多說一點,這里的transforms.ToTensor()接收PIL格式的數據, 或者是直接從PIL轉來的np.ndarray格式數據, 只要保證進來的數據取值范圍是[0, 255], 格式是HWC(H、W、C分別對應圖片高度、寬度、通道數,這也就是我們在日常生活中存儲圖片的常用順序), 像素順序是RGB, 它就會幫我們完成下列的工作:

  • 取值范圍[0, 255] / 255.0 => [0, 1.0], 數據格式從uint8變成了torch.float32
  • 形狀(shape)轉為CHW,但像素順序依舊是RGB。

比如如果不加transforms.ToTensor(),就會直接得到PIL格式的圖片:

from torchvision import datasets, transforms
import numpy as np

transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
    ]
) 
data = datasets.CIFAR10(root=".", download=True,transform=transform)

img = data[0][0]
print(type(img))  # <class 'PIL.Image.Image'>

然后我們可以嘗試先將PIL.Image.Image對象轉為np.ndarray,然后再轉為torch.Tensor類型的對象:

np_img = np.asarray(img)
print(np_img.dtype) # uint8
tensor_from_np = transforms.ToTensor()(np_img)
print(type(tensor_from_np)) # <class 'torch.Tensor'>
print(tensor_from_np.dtype) # torch.float32
print(tensor_from_np.shape) # torch.Size([1, 32, 32])
print(tensor_from_np)
# tensor([[[0.2392, 0.1765, 0.1882,  ..., 0.5373, 0.5098, 0.5059],
#          [0.0745, 0.0000, 0.0392,  ..., 0.3725, 0.3529, 0.3686],
#          [0.0941, 0.0353, 0.1216,  ..., 0.3529, 0.3569, 0.3137],
#          ...,
#          [0.6784, 0.6039, 0.6157,  ..., 0.5255, 0.1412, 0.1490],
#          [0.5725, 0.5059, 0.5647,  ..., 0.6000, 0.2706, 0.2353],
#          [0.5922, 0.5373, 0.5765,  ..., 0.7412, 0.4863, 0.3882]]])

PS: 最后再提一下Tensorflow,Tensorflow雖然調用的tf.keras.datasets.cifar10.load_data()能直接得到類型為numpy.ndarray並按照HWC順序存儲的數據,但是需要手動去添加/255以對數據歸一化,如下所示:

import tensorflow as tf
import numpy as np
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
print(type(x_train)) # <class 'numpy.ndarray'>
print(x_train.shape) # (50000, 32, 32, 3)
print(x_train)
# [[[[ 59  62  63]
#    [ 43  46  45]
#    [ 50  48  43]
#    ...
#    [179 177 173]
#    [164 164 162]
#    [163 163 161]]]]
x_train = x_train.astype(np.float32) / 255.0
print(x_train)
# [[[[0.23137255 0.24313726 0.24705882]
#    [0.16862746 0.18039216 0.1764706 ]
#    [0.19607843 0.1882353  0.16862746]
#    ...
#    [0.7019608  0.69411767 0.6784314 ]
#    [0.6431373  0.6431373  0.63529414]
#    [0.6392157  0.6392157  0.6313726 ]]]]

參考


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM